From a97fca2fb71a644e8b028c7058ed1afb6f087555 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 17 Feb 2025 15:51:24 +0800 Subject: [PATCH 01/37] 1 --- src/diffusers/__init__.py | 2 + .../models/controlnets/controlnet_cogview4.py | 465 ++++++++++++ src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/auto_pipeline.py | 3 +- src/diffusers/pipelines/cogview4/__init__.py | 1 + .../cogview4/pipeline_cogview4_control.py | 713 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 13 + 7 files changed, 1198 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/models/controlnets/controlnet_cogview4.py create mode 100644 src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a9e7c823db41..bfd62cc9fb0f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -289,6 +289,7 @@ "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", "CogView4Pipeline", + "CogView4ControlPipeline", "ConsisIDPipeline", "CycleDiffusionPipeline", "FluxControlImg2ImgPipeline", @@ -788,6 +789,7 @@ CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, CogView4Pipeline, + CogView4ControlPipeline, ConsisIDPipeline, CycleDiffusionPipeline, FluxControlImg2ImgPipeline, diff --git a/src/diffusers/models/controlnets/controlnet_cogview4.py b/src/diffusers/models/controlnets/controlnet_cogview4.py new file mode 100644 index 000000000000..5f34c20792db --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_cogview4.py @@ -0,0 +1,465 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
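+
+# NOTE: This module follows the layout of the SD3 ControlNet implementation; the classes below still
+# carry their SD3 names and are being adapted to the CogView4 transformer blocks and patch embedding.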
+ + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import JointTransformerBlock +from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 +from ..embeddings import CombinedTimestepTextProjEmbeddings +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..transformers.transformer_cogview4 import CogView4TransformerBlock,CogView4PatchEmbed +from .controlnet import BaseOutput, zero_module + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SD3ControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + + +class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 18, + attention_head_dim: int = 64, + num_attention_heads: int = 18, + joint_attention_dim: int = 4096, + caption_projection_dim: int = 1152, + pooled_projection_dim: int = 2048, + out_channels: int = 16, + pos_embed_max_size: int = 96, + extra_conditioning_channels: int = 0, + dual_attention_layers: Tuple[int, ...] = (), + qk_norm: Optional[str] = None, + pos_embed_type: Optional[str] = "sincos", + use_pos_embed: bool = True, + force_zeros_for_pooled_projection: bool = True, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.pos_embed = CogView4PatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + pos_embed_max_size=pos_embed_max_size, + pos_embed_type=pos_embed_type, + ) + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + if joint_attention_dim is not None: + self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim) + + # `attention_head_dim` is doubled to account for the mixing. + # It needs to crafted when we get the actual checkpoints. 
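+            # When `joint_attention_dim` is set, the SD3-style joint transformer blocks below are kept;
+            # otherwise (the `else` branch) `CogView4TransformerBlock`s are used and no context embedder is created.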
+ self.transformer_blocks = nn.ModuleList( + [ + JointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + context_pre_only=False, + qk_norm=qk_norm, + use_dual_attention=True if i in dual_attention_layers else False, + ) + for i in range(num_layers) + ] + ) + else: + self.context_embedder = None + self.transformer_blocks = nn.ModuleList( + [ + CogView4TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + controlnet_block = nn.Linear(self.inner_dim, self.inner_dim) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + pos_embed_input = CogView4PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels + extra_conditioning_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + ) + self.pos_embed_input = zero_module(pos_embed_input) + + self.gradient_checkpointing = False + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer + # we should have handled this in conversion script + def _get_pos_embed_from_transformer(self, transformer): + pos_embed = CogView4PatchEmbed( + height=transformer.config.sample_size, + width=transformer.config.sample_size, + patch_size=transformer.config.patch_size, + in_channels=transformer.config.in_channels, + embed_dim=transformer.inner_dim, + pos_embed_max_size=transformer.config.pos_embed_max_size, + ) + pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True) + return pos_embed + + @classmethod + def from_transformer( + cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True + ): + config = transformer.config + config["num_layers"] = num_layers or config.num_layers + config["extra_conditioning_channels"] = num_extra_conditioning_channels + controlnet = cls.from_config(config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + + controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) + + return controlnet + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
+ """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + if self.pos_embed is not None and hidden_states.ndim != 4: + raise ValueError("hidden_states must be 4D when pos_embed is used") + + # SD3.5 8b controlnet does not have a `pos_embed`, + # it use the `pos_embed` from the transformer to process input before passing to controlnet + elif self.pos_embed is None and hidden_states.ndim != 3: + raise ValueError("hidden_states must be 3D when pos_embed is not used") + + if self.context_embedder is not None and encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be provided when context_embedder is used") + # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states` + elif self.context_embedder is None and encoder_hidden_states is not None: + raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used") + + if self.pos_embed is not None: + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + + temb = self.time_text_embed(timestep, pooled_projections) + + if self.context_embedder is not None: + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # add + hidden_states = hidden_states + self.pos_embed_input(controlnet_cond) + + block_res_samples = () + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + if self.context_embedder is not None: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + ) + else: + # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` + hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb) + + else: + if self.context_embedder is not None: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + else: + # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` + hidden_states = block(hidden_states, temb) + + block_res_samples = block_res_samples + (hidden_states,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + # 6. scaling + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_res_samples,) + + return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) + + +class SD3MultiControlNetModel(ModelMixin): + r""" + `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet + + This module is a wrapper for multiple instances of the `SD3ControlNetModel`. 
The `forward()` API is designed to be + compatible with `SD3ControlNetModel`. + + Args: + controlnets (`List[SD3ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `SD3ControlNetModel` as a list. + """ + + def __init__(self, controlnets): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + pooled_projections: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[SD3ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + block_samples = controlnet( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + pooled_projections=pooled_projections, + controlnet_cond=image, + conditioning_scale=scale, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=return_dict, + ) + + # merge samples + if i == 0: + control_block_samples = block_samples + else: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) + ] + control_block_samples = (tuple(control_block_samples),) + + return control_block_samples diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 49041086f535..a25bca9c34dc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -154,7 +154,7 @@ "CogVideoXFunControlPipeline", ] _import_structure["cogview3"] = ["CogView3PlusPipeline"] - _import_structure["cogview4"] = ["CogView4Pipeline"] + _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["controlnet"].extend( [ @@ -500,7 +500,7 @@ CogVideoXVideoToVideoPipeline, ) from .cogview3 import CogView3PlusPipeline - from .cogview4 import CogView4Pipeline + from .cogview4 import CogView4Pipeline, CogView4ControlPipeline from .consisid import ConsisIDPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 1c38f83a7ef3..2a100ffcaa9e 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -22,7 +22,7 @@ from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline -from .cogview4 import CogView4Pipeline +from .cogview4 import CogView4Pipeline, CogView4ControlPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, @@ -140,6 +140,7 @@ ("lumina2", Lumina2Text2ImgPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), + ("cogview4-control", CogView4ControlPipeline), ] ) diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py index 5a535b3feb4b..531cea7d7c66 100644 --- a/src/diffusers/pipelines/cogview4/__init__.py +++ b/src/diffusers/pipelines/cogview4/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"] + 
_import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py new file mode 100644 index 000000000000..dcfdeaaf8c7c --- /dev/null +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -0,0 +1,713 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4Pipeline + + >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +): + # m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + # b = base_shift - m * base_seq_len + # mu = image_seq_len * m + b + # return mu + + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class CogView4ControlPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for controllable text-to-image generation using CogView4, conditioned on an additional control image
+    (for example a pose or depth map).
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`GlmModel`]):
+            Frozen text-encoder. CogView4 uses the GLM-4-9B text encoder.
+        tokenizer (`PreTrainedTokenizer`):
+            Tokenizer for the GLM text encoder (loaded via `AutoTokenizer`).
+        transformer ([`CogView4Transformer2DModel`]):
+            A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_glm_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 1024, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder( + text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True + ).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. 
+ num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds( + negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> Union[CogView4PipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+            num_images_per_prompt (`int`, *optional*, defaults to `1`):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.cogview4.pipeline_output.CogView4PipelineOutput`] instead
+                of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `1024`):
+                Maximum sequence length in encoded prompt.
Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + vae_shift_factor = 0 + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + 
self.scheduler.config.get("max_shift", 0.75), + ) + _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) + timesteps = torch.from_numpy(timesteps).to(device) + + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2).to(transformer_dtype) + # latent_model_input = latents.to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c853cf8faa55..9b3742127824 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -376,6 +376,19 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogView4ControlPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) 
class ConsisIDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c30ca7a119f0951c913ff36798876d5714e1a9fa Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Mon, 17 Feb 2025 17:59:09 +0800 Subject: [PATCH 02/37] change to channel 1 --- src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index dcfdeaaf8c7c..8c982eaed05f 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -646,9 +646,7 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue - - latent_model_input = torch.cat([latents, control_image], dim=2).to(transformer_dtype) - # latent_model_input = latents.to(transformer_dtype) + latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) From 5c25cd2e3954b355f686f1f917df299f88c76f25 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 14:43:40 +0800 Subject: [PATCH 03/37] cogview4 control training --- examples/cogview4-control/README.md | 201 +++ examples/cogview4-control/requirements.txt | 6 + .../train_control_cogview4.py | 1221 +++++++++++++++++ .../transformers/transformer_cogview4.py | 2 +- 4 files changed, 1429 insertions(+), 1 deletion(-) create mode 100644 examples/cogview4-control/README.md create mode 100644 examples/cogview4-control/requirements.txt create mode 100644 examples/cogview4-control/train_control_cogview4.py diff --git a/examples/cogview4-control/README.md b/examples/cogview4-control/README.md new file mode 100644 index 000000000000..746a99a1a41b --- /dev/null +++ b/examples/cogview4-control/README.md @@ -0,0 +1,201 @@ +# Training CogView4 Control + +This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources: + +To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`. + +> [!NOTE] +> **Gated model** +> +> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: + +```bash +huggingface-cli login +``` + +The example command below shows how to launch fine-tuning for pose conditions. 
The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them. + +```bash +accelerate launch train_control_lora_cogview4.py \ + --pretrained_model_name_or_path="THUDM/CogView4-6B" \ + --dataset_name="raulc0399/open_pose_controlnet" \ + --output_dir="pose-control-lora" \ + --mixed_precision="bf16" \ + --train_batch_size=1 \ + --rank=64 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --use_8bit_adam \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=5000 \ + --validation_image="openpose.png" \ + --validation_prompt="A couple, 4k photo, highly detailed" \ + --offload \ + --seed="0" \ + --push_to_hub +``` + +`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png). + +You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`. + +The training script exposes additional CLI args that might be useful to experiment with: + +* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer. +* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading. +* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached. + +### Training with DeepSpeed + +It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed): + +```yaml +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +``` + +And then while launching training, pass the config file: + +```bash +accelerate launch --config_file=CONFIG_FILE.yaml ... +``` + +### Inference + +The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first: + +```bash +pip install controlnet_aux +``` + +And then we are ready: + +```py +from controlnet_aux import OpenposeDetector +from diffusers import CogView4ControlPipeline +from diffusers.utils import load_image +from PIL import Image +import numpy as np +import torch + +pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("...") # change this. + +open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") + +# prepare pose condition. 
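+# Download a reference photo, run OpenPose on it, and convert the detector's output
+# back to a PIL image so it can be passed to the pipeline as `control_image`.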
+url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg" +image = load_image(url) +image = open_pose(image, detect_resolution=512, image_resolution=1024) +image = np.array(image)[:, :, ::-1] +image = Image.fromarray(np.uint8(image)) + +prompt = "A couple, 4k photo, highly detailed" + +gen_images = pipe( + prompt=prompt, + control_image=image, + num_inference_steps=50, + joint_attention_kwargs={"scale": 0.9}, + guidance_scale=25., +).images[0] +gen_images.save("output.png") +``` + +## Full fine-tuning + +We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command: + +```bash +accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \ + --pretrained_model_name_or_path="THUDM/CogView4-6B" \ + --dataset_name="raulc0399/open_pose_controlnet" \ + --output_dir="pose-control" \ + --mixed_precision="bf16" \ + --train_batch_size=2 \ + --dataloader_num_workers=4 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --use_8bit_adam \ + --proportion_empty_prompts=0.2 \ + --learning_rate=5e-5 \ + --adam_weight_decay=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="cosine" \ + --lr_warmup_steps=1000 \ + --checkpointing_steps=1000 \ + --max_train_steps=10000 \ + --validation_steps=200 \ + --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \ + --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \ + --offload \ + --seed="0" \ + --push_to_hub +``` + +Change the `validation_image` and `validation_prompt` as needed. + +For inference, this time, we will run: + +```py +from controlnet_aux import OpenposeDetector +from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel +from diffusers.utils import load_image +from PIL import Image +import numpy as np +import torch + +transformer = CogView4Transformer2DModel.from_pretrained("...") # change this. +pipe = CogView4ControlPipeline.from_pretrained( + "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16 +).to("cuda") + +open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") + +# prepare pose condition. +url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg" +image = load_image(url) +image = open_pose(image, detect_resolution=512, image_resolution=1024) +image = np.array(image)[:, :, ::-1] +image = Image.fromarray(np.uint8(image)) + +prompt = "A couple, 4k photo, highly detailed" + +gen_images = pipe( + prompt=prompt, + control_image=image, + num_inference_steps=50, + guidance_scale=25., +).images[0] +gen_images.save("output.png") +``` + +## Things to note + +* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗 +* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified. +* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality. 
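+
+For reference, the input-channel expansion described at the top of this README is just a zero-initialized widening of the patch embedder's input projection. The snippet below is a simplified sketch of what `train_control_cogview4.py` does at startup (dtype/device handling and config bookkeeping omitted):
+
+```py
+import torch
+from diffusers import CogView4Transformer2DModel
+
+transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer")
+
+with torch.no_grad():
+    proj = transformer.patch_embed.proj
+    in_features = proj.in_features  # 64 = in_channels (16) * patch_size (2) ** 2
+    new_proj = torch.nn.Linear(in_features * 2, proj.out_features, bias=proj.bias is not None)
+    new_proj.weight.zero_()  # zero-init so the control latents start out as a no-op
+    new_proj.weight[:, :in_features].copy_(proj.weight)  # keep the pretrained mapping for the noisy latents
+    if proj.bias is not None:
+        new_proj.bias.copy_(proj.bias)
+    transformer.patch_embed.proj = new_proj
+```
+
+Because the new weight columns start at zero, the expanded model initially ignores the control latents and behaves like the base text-to-image model; the control pathway is then learned during fine-tuning.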
\ No newline at end of file diff --git a/examples/cogview4-control/requirements.txt b/examples/cogview4-control/requirements.txt new file mode 100644 index 000000000000..6c5ec2e03f9a --- /dev/null +++ b/examples/cogview4-control/requirements.txt @@ -0,0 +1,6 @@ +transformers==4.47.0 +wandb +torch +torchvision +accelerate==1.2.0 +peft>=0.14.0 diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py new file mode 100644 index 000000000000..8b4b81e7ff27 --- /dev/null +++ b/examples/cogview4-control/train_control_cogview4.py @@ -0,0 +1,1221 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import logging +import math +import os +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, CogView4ControlPipeline,CogView4Transformer2DModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.33.0.dev0") + +logger = get_logger(__name__) + +NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] + + +def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype): + pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample() + pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor + return pixel_latents.to(weight_dtype) + + +def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... 
") + + if not is_final_validation: + cogview4_transformer = accelerator.unwrap_model(cogview4_transformer) + pipeline = CogView4ControlPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=cogview4_transformer, + torch_dtype=weight_dtype, + ) + else: + transformer = CogView4Transformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + pipeline = CogView4ControlPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + torch_dtype=weight_dtype, + ) + + pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype) + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = load_image(validation_image) + # maybe need to inference on 1024 to get a good image + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=validation_prompt, + control_image=validation_image, + num_inference_steps=50, + guidance_scale=args.guidance_scale, + generator=generator, + max_sequence_length=512, + height=args.resolution, + width=args.resolution, + ).images[0] + image = image.resize((args.resolution, args.resolution)) + images.append(image) + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + formatted_images = [] + formatted_images.append(np.asarray(validation_image)) + for image in images: + formatted_images.append(np.asarray(image)) + formatted_images = np.stack(formatted_images) + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + + elif tracker.name == "wandb": + formatted_images = [] + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + formatted_images.append(wandb.Image(validation_image, caption="Conditioning")) + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + free_memory() + return image_logs + + +def save_model_card(repo_id: 
str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# cogview4-control-{repo_id} + +These are Control weights trained on {base_model} with new type of conditioning. +{img_str} + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogView4-6b/blob/main/LICENSE.md) +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "cogview4", + "cogview4-diffusers", + "text-to-image", + "diffusers", + "control", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a CogView4 Control training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogview4-control", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the control conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.") + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the control conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." 
+ ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=1, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="cogview4_train_control", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--jsonl_for_train", + type=str, + default=None, + help="Path to the jsonl file containing the training data.", + ) + parser.add_argument( + "--only_target_transformer_blocks", + action="store_true", + help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=30.0, + help="the guidance scale used for transformer.", + ) + + parser.add_argument( + "--upcast_before_saving", + action="store_true", + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoders to CPU when they are not used.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.jsonl_for_train is None: + raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`") + + if args.dataset_name is not None and args.jsonl_for_train is not None: + raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the cogview4 transformer." + ) + + return args + + +def get_train_dataset(args, accelerator): + dataset = None + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + if args.jsonl_for_train is not None: + # load from json + dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir) + dataset = dataset.flatten_indices() + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + + def preprocess_train(examples): + images = [ + (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) + for image in examples[args.image_column] + ] + images = [image_transforms(image) for image in images] + + conditioning_images = [ + (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) + for image in examples[args.conditioning_image_column] + ] + conditioning_images = [image_transforms(image) for image in conditioning_images] + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + is_caption_list = isinstance(examples[args.caption_column][0], list) + if is_caption_list: + examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]] + else: + examples["captions"] = list(examples[args.caption_column]) + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + captions = [example["captions"] for example in examples] + return {"pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "captions": captions} + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_out_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir)) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices. + if torch.backends.mps.is_available(): + logger.info("MPS is enabled. Disabling AMP.") + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + # DEBUG, INFO, WARNING, ERROR, CRITICAL + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load models. We will load the text encoders later in a pipeline to compute + # embeddings. + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + cogview4_transformer = CogView4Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + logger.info("All models loaded successfully") + + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + if not args.only_target_transformer_blocks: + cogview4_transformer.requires_grad_(True) + vae.requires_grad_(False) + + # cast down and move to the CPU + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # let's not move the VAE to the GPU yet. + vae.to(dtype=torch.float32) # keep the VAE in float32. 
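+
+    # The "enable image inputs" block below widens the patch embedder's input projection so the
+    # transformer can consume the channel-wise concatenation of noisy latents and control latents
+    # (64 -> 128 input features after patchification). The new weight columns are zero-initialized,
+    # so training starts from the behavior of the base model.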
+
+    # enable image inputs
+    with torch.no_grad():
+        patch_size = cogview4_transformer.config.patch_size
+        initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2
+        new_linear = torch.nn.Linear(
+            cogview4_transformer.patch_embed.proj.in_features * 2,
+            cogview4_transformer.patch_embed.proj.out_features,
+            bias=cogview4_transformer.patch_embed.proj.bias is not None,
+            dtype=cogview4_transformer.dtype,
+            device=cogview4_transformer.device,
+        )
+        new_linear.weight.zero_()
+        new_linear.weight[:, :initial_input_channels].copy_(cogview4_transformer.patch_embed.proj.weight)
+        if cogview4_transformer.patch_embed.proj.bias is not None:
+            new_linear.bias.copy_(cogview4_transformer.patch_embed.proj.bias)
+        cogview4_transformer.patch_embed.proj = new_linear
+
+    assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)
+    cogview4_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
+
+    if args.only_target_transformer_blocks:
+        cogview4_transformer.patch_embed.proj.requires_grad_(True)
+        for name, module in cogview4_transformer.named_modules():
+            if "transformer_blocks" in name:
+                module.requires_grad_(True)
+            else:
+                module.requires_grad_(False)
+
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+
+        def save_model_hook(models, weights, output_dir):
+            if accelerator.is_main_process:
+                for model in models:
+                    if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))):
+                        model = unwrap_model(model)
+                        model.save_pretrained(os.path.join(output_dir, "transformer"))
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    if weights:
+                        weights.pop()
+
+        def load_model_hook(models, input_dir):
+            transformer_ = None
+
+            if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+                while len(models) > 0:
+                    model = models.pop()
+
+                    if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))):
+                        transformer_ = model  # noqa: F841
+                    else:
+                        raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}")
+
+            else:
+                transformer_ = CogView4Transformer2DModel.from_pretrained(input_dir, subfolder="transformer")  # noqa: F841
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+    if args.gradient_checkpointing:
+        cogview4_transformer.enable_gradient_checkpointing()
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimization parameters + optimizer = optimizer_class( + cogview4_transformer.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Prepare dataset and dataloader. + train_dataset = get_train_dataset(args, accelerator) + train_dataset = prepare_train_dataset(train_dataset, accelerator) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + # Prepare everything with our `accelerator`. + cogview4_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + cogview4_transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! 
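+    # Effective batch size = per-device batch size * number of processes * gradient accumulation steps.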
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed. + text_encoding_pipeline = CogView4ControlPipeline.from_pretrained( + args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype + ) + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples: + logger.info("Logging some dataset samples.") + formatted_images = [] + formatted_control_images = [] + all_prompts = [] + for i, batch in enumerate(train_dataloader): + images = (batch["pixel_values"] + 1) / 2 + control_images = (batch["conditioning_pixel_values"] + 1) / 2 + prompts = batch["captions"] + + if len(formatted_images) > 10: + break + + for img, control_img, prompt in zip(images, control_images, prompts): + formatted_images.append(img) + formatted_control_images.append(control_img) + all_prompts.append(prompt) + + logged_artifacts = [] + for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts): + logged_artifacts.append(wandb.Image(control_img, caption="Conditioning")) + logged_artifacts.append(wandb.Image(img, caption=prompt)) + + wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"] + wandb_tracker[0].log({"dataset_samples": logged_artifacts}) + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + cogview4_transformer.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(cogview4_transformer): + # Convert images to latent space + # vae encode + pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype) + control_latents = encode_images( + batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype + ) + if args.offload: + vae.cpu() + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + bsz = pixel_latents.shape[0] + noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype) + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + + # Add noise according for cogview4 + #FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one. + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) + sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device) + captions = batch["captions"] + token_lengths = [len(caption.split()) for caption in captions] + token_per_sample = max(token_lengths) + image_seq_lens = torch.tensor(token_per_sample // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device) + mu = torch.sqrt(image_seq_lens / 256) + mu = mu * 0.75 + 0.25 + scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device) + scale_factors = scale_factors.view(4, 1, 1, 1) + noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise + concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1) + text_encoding_pipeline = text_encoding_pipeline.to("cuda") + + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, = text_encoding_pipeline.encode_prompt( + captions, None + ) + original_size = (args.resolution, args.resolution) + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) + + target_size = (args.resolution,args.resolution) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) + + target_size = target_size.repeat(len(batch["captions"]), 1) + original_size = original_size.repeat(len(batch["captions"]), 1) + + #TODO: Should a parameter be set here for passing? This is not present in Flux. + crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) + crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) + + # Predict. 
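+                # The transformer receives the channel-wise concatenation of noisy latents and
+                # control latents built above; with the flow-matching objective used here, the
+                # regression target computed below is (noise - clean latents).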
+ model_pred = cogview4_transformer( + hidden_states=concatenated_noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # flow-matching loss + target = noise - pixel_latents + + weighting = weighting.unsqueeze(1).unsqueeze(2).unsqueeze(3) # [4, 1, 1, 1] + loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),1) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = cogview4_transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + cogview4_transformer=cogview4_transformer, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + cogview4_transformer = unwrap_model(cogview4_transformer) + if args.upcast_before_saving: + cogview4_transformer.to(torch.float32) + cogview4_transformer.save_pretrained(args.output_dir) + + del cogview4_transformer + del text_encoding_pipeline + del vae + free_memory() + + # Run a final round of validation. 
+ image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + cogview4_transformer=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*", "checkpoint-*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index f622791b572f..e608b16b64bc 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -35,7 +35,7 @@ class CogView4PatchEmbed(nn.Module): def __init__( self, in_channels: int = 16, - hidden_size: int = 2560, + hidden_size: int = 4096, patch_size: int = 2, text_hidden_size: int = 4096, ): From 44bfd4c8277931f3584baa89f4cade3726236991 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 16:01:01 +0800 Subject: [PATCH 04/37] add CacheMixin --- examples/cogview4-control/train_control_cogview4.py | 5 ++--- src/diffusers/models/transformers/transformer_cogview4.py | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 8b4b81e7ff27..7d445da3ac67 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -127,7 +127,6 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator, - max_sequence_length=512, height=args.resolution, width=args.resolution, ).images[0] @@ -1075,7 +1074,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): mu = torch.sqrt(image_seq_lens / 256) mu = mu * 0.75 + 0.25 scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device) - scale_factors = scale_factors.view(4, 1, 1, 1) + scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1) noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1) text_encoding_pipeline = text_encoding_pipeline.to("cuda") @@ -1114,7 +1113,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # flow-matching loss target = noise - pixel_latents - weighting = weighting.unsqueeze(1).unsqueeze(2).unsqueeze(3) # [4, 1, 1, 1] + weighting = weighting.view(len(batch["captions"]), 1, 1, 1) loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),1) loss = loss.mean() accelerator.backward(loss) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index e608b16b64bc..db8e69790ba0 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -17,13 +17,14 @@ import torch import torch.nn as nn import torch.nn.functional as F - +from ...loaders import PeftAdapterMixin from 
...configuration_utils import ConfigMixin, register_to_config from ...models.attention import FeedForward from ...models.attention_processor import Attention from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous from ...utils import logging +from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput @@ -285,6 +286,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens class CogView4Transformer2DModel(ModelMixin, ConfigMixin): +class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): r""" Args: patch_size (`int`, defaults to `2`): @@ -390,7 +392,6 @@ def forward( p = self.config.patch_size post_patch_height = height // p post_patch_width = width // p - hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states) temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) From a9f448e30da222717641fffe2b8f8df0b41f36c5 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 16:18:01 +0800 Subject: [PATCH 05/37] 1 --- .../train_control_cogview4.py | 12 - .../models/controlnets/controlnet_cogview4.py | 465 ------------------ 2 files changed, 477 deletions(-) delete mode 100644 src/diffusers/models/controlnets/controlnet_cogview4.py diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 7d445da3ac67..8273a81c1d4b 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -1025,18 +1025,6 @@ def load_model_hook(models, input_dir): disable=not accelerator.is_local_main_process, ) - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - image_logs = None for epoch in range(first_epoch, args.num_train_epochs): cogview4_transformer.train() for step, batch in enumerate(train_dataloader): diff --git a/src/diffusers/models/controlnets/controlnet_cogview4.py b/src/diffusers/models/controlnets/controlnet_cogview4.py deleted file mode 100644 index 5f34c20792db..000000000000 --- a/src/diffusers/models/controlnets/controlnet_cogview4.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import JointTransformerBlock -from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 -from ..embeddings import CombinedTimestepTextProjEmbeddings -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..transformers.transformer_cogview4 import CogView4TransformerBlock,CogView4PatchEmbed -from .controlnet import BaseOutput, zero_module - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class SD3ControlNetOutput(BaseOutput): - controlnet_block_samples: Tuple[torch.Tensor] - - -class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: int = 128, - patch_size: int = 2, - in_channels: int = 16, - num_layers: int = 18, - attention_head_dim: int = 64, - num_attention_heads: int = 18, - joint_attention_dim: int = 4096, - caption_projection_dim: int = 1152, - pooled_projection_dim: int = 2048, - out_channels: int = 16, - pos_embed_max_size: int = 96, - extra_conditioning_channels: int = 0, - dual_attention_layers: Tuple[int, ...] = (), - qk_norm: Optional[str] = None, - pos_embed_type: Optional[str] = "sincos", - use_pos_embed: bool = True, - force_zeros_for_pooled_projection: bool = True, - ): - super().__init__() - default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels - self.inner_dim = num_attention_heads * attention_head_dim - self.pos_embed = CogView4PatchEmbed( - patch_size=patch_size, - in_channels=in_channels, - pos_embed_max_size=pos_embed_max_size, - pos_embed_type=pos_embed_type, - ) - - self.time_text_embed = CombinedTimestepTextProjEmbeddings( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - if joint_attention_dim is not None: - self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim) - - # `attention_head_dim` is doubled to account for the mixing. - # It needs to crafted when we get the actual checkpoints. 
- self.transformer_blocks = nn.ModuleList( - [ - JointTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - context_pre_only=False, - qk_norm=qk_norm, - use_dual_attention=True if i in dual_attention_layers else False, - ) - for i in range(num_layers) - ] - ) - else: - self.context_embedder = None - self.transformer_blocks = nn.ModuleList( - [ - CogView4TransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for _ in range(num_layers) - ] - ) - - # controlnet_blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - controlnet_block = nn.Linear(self.inner_dim, self.inner_dim) - controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks.append(controlnet_block) - pos_embed_input = CogView4PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels + extra_conditioning_channels, - embed_dim=self.inner_dim, - pos_embed_type=None, - ) - self.pos_embed_input = zero_module(pos_embed_input) - - self.gradient_checkpointing = False - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - """ - Sets the attention processor to use [feed forward - chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). - - Parameters: - chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually - over each tensor of dim=`dim`. - dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) - or dim=1 (sequence length). - """ - if dim not in [0, 1]: - raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") - - # By default chunk size is 1 - chunk_size = chunk_size or 1 - - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, chunk_size, dim) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedJointAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer - # we should have handled this in conversion script - def _get_pos_embed_from_transformer(self, transformer): - pos_embed = CogView4PatchEmbed( - height=transformer.config.sample_size, - width=transformer.config.sample_size, - patch_size=transformer.config.patch_size, - in_channels=transformer.config.in_channels, - embed_dim=transformer.inner_dim, - pos_embed_max_size=transformer.config.pos_embed_max_size, - ) - pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True) - return pos_embed - - @classmethod - def from_transformer( - cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True - ): - config = transformer.config - config["num_layers"] = num_layers or config.num_layers - config["extra_conditioning_channels"] = num_extra_conditioning_channels - controlnet = cls.from_config(config) - - if load_weights_from_transformer: - controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) - controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) - controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) - - controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) - - return controlnet - - def forward( - self, - hidden_states: torch.FloatTensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - encoder_hidden_states: torch.FloatTensor = None, - pooled_projections: torch.FloatTensor = None, - timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`SD3Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. 
- """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - - if self.pos_embed is not None and hidden_states.ndim != 4: - raise ValueError("hidden_states must be 4D when pos_embed is used") - - # SD3.5 8b controlnet does not have a `pos_embed`, - # it use the `pos_embed` from the transformer to process input before passing to controlnet - elif self.pos_embed is None and hidden_states.ndim != 3: - raise ValueError("hidden_states must be 3D when pos_embed is not used") - - if self.context_embedder is not None and encoder_hidden_states is None: - raise ValueError("encoder_hidden_states must be provided when context_embedder is used") - # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states` - elif self.context_embedder is None and encoder_hidden_states is not None: - raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used") - - if self.pos_embed is not None: - hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. - - temb = self.time_text_embed(timestep, pooled_projections) - - if self.context_embedder is not None: - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - # add - hidden_states = hidden_states + self.pos_embed_input(controlnet_cond) - - block_res_samples = () - - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - if self.context_embedder is not None: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - ) - else: - # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` - hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb) - - else: - if self.context_embedder is not None: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) - else: - # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` - hidden_states = block(hidden_states, temb) - - block_res_samples = block_res_samples + (hidden_states,) - - controlnet_block_res_samples = () - for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): - block_res_sample = controlnet_block(block_res_sample) - controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - - # 6. scaling - controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (controlnet_block_res_samples,) - - return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) - - -class SD3MultiControlNetModel(ModelMixin): - r""" - `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet - - This module is a wrapper for multiple instances of the `SD3ControlNetModel`. 
The `forward()` API is designed to be - compatible with `SD3ControlNetModel`. - - Args: - controlnets (`List[SD3ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `SD3ControlNetModel` as a list. - """ - - def __init__(self, controlnets): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - hidden_states: torch.FloatTensor, - controlnet_cond: List[torch.tensor], - conditioning_scale: List[float], - pooled_projections: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[SD3ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): - block_samples = controlnet( - hidden_states=hidden_states, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - pooled_projections=pooled_projections, - controlnet_cond=image, - conditioning_scale=scale, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=return_dict, - ) - - # merge samples - if i == 0: - control_block_samples = block_samples - else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) - ] - control_block_samples = (tuple(control_block_samples),) - - return control_block_samples From 2cbdf355e46f9942d7296fcf187d2753b33282c6 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 16:47:14 +0800 Subject: [PATCH 06/37] remove initial_input_channels change for val --- examples/cogview4-control/train_control_cogview4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 8273a81c1d4b..ba53f6cdcd96 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -804,7 +804,6 @@ def main(args): cogview4_transformer.patch_embed.proj = new_linear assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0) - cogview4_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels) if args.only_target_transformer_blocks: cogview4_transformer.patch_embed.proj.requires_grad_(True) From df83bf2c7415ac2a14b0555ed17338f9761129d9 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 19:55:32 +0800 Subject: [PATCH 07/37] 1 --- examples/cogview4-control/train_control_cogview4.py | 1 + src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index ba53f6cdcd96..9410bf96b51b 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -804,6 +804,7 @@ def main(args): cogview4_transformer.patch_embed.proj = new_linear assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0) + cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=initial_input_channels) if args.only_target_transformer_blocks: cogview4_transformer.patch_embed.proj.requires_grad_(True) diff --git 
a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 8c982eaed05f..fdb88bce7cd0 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -579,7 +579,7 @@ def __call__( ) # Prepare latents - latent_channels = self.transformer.config.in_channels + latent_channels = self.transformer.config.in_channels // 2 control_image = self.prepare_image( image=control_image, From 8bba67afbc3614bc9c1fa24d1ac66515f753893f Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 20:03:05 +0800 Subject: [PATCH 08/37] update --- src/diffusers/models/transformers/transformer_cogview4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index db8e69790ba0..e9c7d3264508 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -285,7 +285,6 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens return (freqs.cos(), freqs.sin()) -class CogView4Transformer2DModel(ModelMixin, ConfigMixin): class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): r""" Args: From b9d864b7df28d6376c6c81c1ed8b9ec151e3032b Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 20:35:44 +0800 Subject: [PATCH 09/37] use 3.5 --- examples/cogview4-control/train_control_cogview4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 9410bf96b51b..16465c988742 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -511,7 +511,7 @@ def parse_args(input_args=None): parser.add_argument( "--guidance_scale", type=float, - default=30.0, + default=3.5, help="the guidance scale used for transformer.", ) From 5d2e994bc4cd20e948d5eeda72552cad2d898b9a Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 18 Feb 2025 23:28:03 +0800 Subject: [PATCH 10/37] new loss --- examples/cogview4-control/train_control_cogview4.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 16465c988742..8878979a4645 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -41,7 +41,6 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import ( compute_density_for_timestep_sampling, - compute_loss_weighting_for_sd3, free_memory, ) from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid @@ -804,7 +803,7 @@ def main(args): cogview4_transformer.patch_embed.proj = new_linear assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0) - cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=initial_input_channels) + cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels) if args.only_target_transformer_blocks: 
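The two edits above are two halves of the same bookkeeping: training concatenates the control latents to the noisy latents along the channel axis, so the widened patch-embed projection and the registered config see twice the channels, while the control pipeline halves the stored `in_channels` to recover how many channels the plain image latents need. A rough shape check, assuming the usual 16-channel CogView4 latents:

import torch

latent_channels = 16                                   # transformer config before patching
noisy_latents = torch.randn(1, latent_channels, 128, 128)
control_latents = torch.randn(1, latent_channels, 128, 128)

# Training side: channel-wise concatenation doubles the transformer input.
model_input = torch.cat([noisy_latents, control_latents], dim=1)
assert model_input.shape[1] == latent_channels * 2     # value written by register_to_config

# Inference side: the pipeline halves the registered value when sampling latents.
registered_in_channels = latent_channels * 2
assert registered_in_channels // 2 == latent_channels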
cogview4_transformer.patch_embed.proj.requires_grad_(True) @@ -1097,7 +1096,7 @@ def load_model_hook(models, input_dir): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = (sigmas**-2.0).float() # flow-matching loss target = noise - pixel_latents From 95e85048e2b69facf80c014c49ed8af28f993e0e Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 19 Feb 2025 13:48:09 +0800 Subject: [PATCH 11/37] 1 --- examples/cogview4-control/train_control_cogview4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 8878979a4645..8185ac150108 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -41,6 +41,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import ( compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, free_memory, ) from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid @@ -1096,7 +1097,7 @@ def load_model_hook(models, input_dir): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = (sigmas**-2.0).float() + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow-matching loss target = noise - pixel_latents From 7a68a3e1e7e6e50bf919d075ea3f4978c085ba40 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 19 Feb 2025 15:48:38 +0800 Subject: [PATCH 12/37] use imagetoken --- examples/cogview4-control/train_control_cogview4.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 8185ac150108..512d45b76c46 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -1056,9 +1056,7 @@ def load_model_hook(models, input_dir): timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device) captions = batch["captions"] - token_lengths = [len(caption.split()) for caption in captions] - token_per_sample = max(token_lengths) - image_seq_lens = torch.tensor(token_per_sample // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device) + image_seq_lens = torch.tensor(pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device) # H * W / VAE patch_size mu = torch.sqrt(image_seq_lens / 256) mu = mu * 0.75 + 0.25 scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device) From 2a81772076bff46b8e88e3c97416d60f87ef4af3 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 19 Feb 2025 18:44:31 +0800 Subject: [PATCH 13/37] for megatron convert --- scripts/convert_cogview4_to_diffusers.py | 14 ++++++++++++-- scripts/convert_cogview4_to_diffusers_megatron.py | 10 ++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py index 484c817dd938..a78319b50d91 100644 --- 
a/scripts/convert_cogview4_to_diffusers.py +++ b/scripts/convert_cogview4_to_diffusers.py @@ -53,8 +53,18 @@ # this is specific to `AdaLayerNormContinuous`: # diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale def swap_scale_shift(weight, dim): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) + """ + Swap the scale and shift components in the weight tensor. + + Args: + weight (torch.Tensor): The original weight tensor. + dim (int): The dimension along which to split. + + Returns: + torch.Tensor: The modified weight tensor with scale and shift swapped. + """ + shift, scale = weight.chunk(2, dim=dim) + new_weight = torch.cat([scale, shift], dim=dim) return new_weight diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index de5354952493..69b90054a0a5 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -189,14 +189,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers( block_prefix = f"transformer_blocks.{i}." # AdaLayerNorm - new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift( - mega[f"decoder.layers.{i}.adaln.weight"], dim=0 - ) - new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift( - mega[f"decoder.layers.{i}.adaln.bias"], dim=0 - ) - - # QKV + new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"] + new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"] qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"] qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"] From 1d91a24ec855b30d47c5ae073f58556a4938537e Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 19 Feb 2025 19:52:01 +0800 Subject: [PATCH 14/37] 1 --- scripts/convert_cogview4_to_diffusers_megatron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index 69b90054a0a5..b5cb7361fb71 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -215,7 +215,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers( # Attention Output new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[ f"decoder.layers.{i}.self_attention.linear_proj.weight" - ].T + ] new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[ f"decoder.layers.{i}.self_attention.linear_proj.bias" ] From dff4b291c61c49438f154b5d6ea2627ef75ec542 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 20 Feb 2025 00:02:28 +0800 Subject: [PATCH 15/37] train con and uc --- .../cogview4-control/train_control_cogview4.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 512d45b76c46..6a0ce12c9e3d 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -758,7 +758,6 @@ def main(args): revision=args.revision, variant=args.variant, ) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) cogview4_transformer = CogView4Transformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", @@ -1081,9 +1080,8 
@@ def load_model_hook(models, input_dir): #TODO: Should a parameter be set here for passing? This is not present in Flux. crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) - # Predict. - model_pred = cogview4_transformer( + noise_pred_cond = cogview4_transformer( hidden_states=concatenated_noisy_model_input, encoder_hidden_states=prompt_embeds, timestep=timesteps, @@ -1093,6 +1091,16 @@ def load_model_hook(models, input_dir): return_dict=False, )[0] + noise_pred_uncond = cogview4_transformer( + hidden_states=concatenated_noisy_model_input, + encoder_hidden_states=pooled_prompt_embeds, + timestep=timesteps, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + model_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) From b007be09e46fce34e9f01ae0ee6f6bc3992121b3 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 20 Feb 2025 01:10:30 +0800 Subject: [PATCH 16/37] 2 --- examples/cogview4-control/train_control_cogview4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 6a0ce12c9e3d..1225e1678b32 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -1066,7 +1066,7 @@ def load_model_hook(models, input_dir): with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, = text_encoding_pipeline.encode_prompt( - captions, None + captions, "" ) original_size = (args.resolution, args.resolution) original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) From 25f4e4ba569f3811ce55064a41174db4cc6bf32b Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 20 Feb 2025 20:13:26 +0800 Subject: [PATCH 17/37] remove guidance_scale --- examples/cogview4-control/train_control_cogview4.py | 2 +- scripts/convert_cogview4_to_diffusers.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 1225e1678b32..515e1551c13c 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -1100,7 +1100,7 @@ def load_model_hook(models, input_dir): crop_coords=crops_coords_top_left, return_dict=False, )[0] - model_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) + model_pred = noise_pred_uncond + (noise_pred_cond - noise_pred_uncond) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py index a78319b50d91..b6d01c797aeb 100644 --- a/scripts/convert_cogview4_to_diffusers.py +++ b/scripts/convert_cogview4_to_diffusers.py @@ -210,6 +210,7 @@ def main(args): "norm_num_groups": 32, "sample_size": 1024, "scaling_factor": 1.0, + "shift_factor": 0.0, "force_upcast": True, 
"use_quant_conv": False, "use_post_quant_conv": False, From 7ffecbcbf935bf8484c95c57ef7002c9fd5e346d Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Fri, 21 Feb 2025 14:40:07 +0800 Subject: [PATCH 18/37] Update pipeline_cogview4_control.py --- .../cogview4/pipeline_cogview4_control.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index fdb88bce7cd0..08294b832672 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -45,7 +45,7 @@ >>> import torch >>> from diffusers import CogView4Pipeline - >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A photo of an astronaut riding a horse on mars" @@ -60,17 +60,11 @@ def calculate_shift( base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75, -): - # m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - # b = base_shift - m * base_seq_len - # mu = image_seq_len * m + b - # return mu - +) -> float: m = (image_seq_len / base_seq_len) ** 0.5 mu = m * max_shift + base_shift return mu - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -224,6 +218,7 @@ def _get_glm_embeds( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -627,7 +622,7 @@ def __call__( if timesteps is None else np.array(timesteps) ) - timesteps = timesteps.astype(np.int64) + timesteps = timesteps.astype(np.int64).astype(np.float32) sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas mu = calculate_shift( image_seq_len, @@ -635,8 +630,7 @@ def __call__( self.scheduler.config.get("base_shift", 0.25), self.scheduler.config.get("max_shift", 0.75), ) - _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) - timesteps = torch.from_numpy(timesteps).to(device) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu) # Denoising loop transformer_dtype = self.transformer.dtype From b4e11e7d9e14cf41a5be227a825037ea1cc3d72a Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Fri, 21 Feb 2025 14:42:24 +0800 Subject: [PATCH 19/37] fix --- src/diffusers/models/transformers/transformer_cogview4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index e9c7d3264508..5248ced02c43 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -36,7 +36,7 @@ class CogView4PatchEmbed(nn.Module): def __init__( self, in_channels: int = 16, - hidden_size: int = 4096, + hidden_size: int = 2560, patch_size: int = 2, text_hidden_size: int = 4096, ): From f55e3cc44cde43176158de477f9b026410c26428 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang 
<2448370773@qq.com> Date: Fri, 21 Feb 2025 15:54:33 +0800 Subject: [PATCH 20/37] use cogview4 pipeline with timestep --- .../train_control_cogview4.py | 40 +++++++++++++------ .../cogview4/pipeline_cogview4_control.py | 27 +++++++++---- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 515e1551c13c..64034d6fb740 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -37,7 +37,12 @@ from tqdm.auto import tqdm import diffusers -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, CogView4ControlPipeline,CogView4Transformer2DModel +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + CogView4ControlPipeline, + CogView4Transformer2DModel, +) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( compute_density_for_timestep_sampling, @@ -787,7 +792,7 @@ def main(args): # enable image inputs with torch.no_grad(): - patch_size = cogview4_transformer.config.patch_size + patch_size = cogview4_transformer.config.patch_size initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2 new_linear = torch.nn.Linear( cogview4_transformer.patch_embed.proj.in_features * 2, @@ -803,7 +808,9 @@ def main(args): cogview4_transformer.patch_embed.proj = new_linear assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0) - cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels) + cogview4_transformer.register_to_config( + in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels + ) if args.only_target_transformer_blocks: cogview4_transformer.patch_embed.proj.requires_grad_(True) @@ -1050,34 +1057,41 @@ def load_model_hook(models, input_dir): ) # Add noise according for cogview4 - #FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one. + # FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one. 
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device) captions = batch["captions"] - image_seq_lens = torch.tensor(pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size ** 2, dtype=pixel_latents.dtype, device=pixel_latents.device) # H * W / VAE patch_size + image_seq_lens = torch.tensor( + pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2, + dtype=pixel_latents.dtype, + device=pixel_latents.device, + ) # H * W / VAE patch_size mu = torch.sqrt(image_seq_lens / 256) mu = mu * 0.75 + 0.25 - scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device) + scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to( + dtype=pixel_latents.dtype, device=pixel_latents.device + ) scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1) noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1) text_encoding_pipeline = text_encoding_pipeline.to("cuda") with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, = text_encoding_pipeline.encode_prompt( - captions, "" - ) + ( + prompt_embeds, + pooled_prompt_embeds, + ) = text_encoding_pipeline.encode_prompt(captions, "") original_size = (args.resolution, args.resolution) original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) - target_size = (args.resolution,args.resolution) + target_size = (args.resolution, args.resolution) target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) target_size = target_size.repeat(len(batch["captions"]), 1) original_size = original_size.repeat(len(batch["captions"]), 1) - #TODO: Should a parameter be set here for passing? This is not present in Flux. + # TODO: Should a parameter be set here for passing? This is not present in Flux. crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) # Predict. @@ -1108,7 +1122,9 @@ def load_model_hook(models, input_dir): target = noise - pixel_latents weighting = weighting.view(len(batch["captions"]), 1, 1, 1) - loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),1) + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1 + ) loss = loss.mean() accelerator.backward(loss) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 08294b832672..626c74a114ee 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -65,6 +65,7 @@ def calculate_shift( mu = m * max_shift + base_shift return mu + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -97,10 +98,19 @@ def retrieve_timesteps( `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. 
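The training hunk above and the pipeline's `calculate_shift` implement the same resolution-dependent schedule shift: the number of image tokens fixes mu, and mu then warps every sampled sigma toward higher noise before the latents are mixed with noise. A worked sketch under common assumptions (1024x1024 pixels, 8x VAE downsampling, 16 latent channels, patch size 2):

import torch

height = width = 1024 // 8                        # 128x128 latents
patch_size = 2
image_seq_len = height * width // patch_size**2   # 4096 image tokens

base_shift, max_shift = 0.25, 0.75
mu = (image_seq_len / 256) ** 0.5 * max_shift + base_shift   # sqrt(16) * 0.75 + 0.25 = 3.25

# mu > 1 pushes every sigma toward the noisy end of the schedule.
sigmas = torch.tensor([0.25, 0.50, 0.75])
shifted = mu / (mu + (1 / sigmas - 1))            # ~0.52, ~0.76, ~0.91

# Flow-matching noising and target with the shifted schedule.
latents = torch.randn(3, 16, height, width)
noise = torch.randn_like(latents)
s = shifted.view(-1, 1, 1, 1)
noisy_model_input = (1.0 - s) * latents + s * noise
target = noise - latents                          # velocity target used by the loss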
""" + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -109,9 +119,8 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." @@ -630,8 +639,10 @@ def __call__( self.scheduler.config.get("base_shift", 0.25), self.scheduler.config.get("max_shift", 0.75), ) - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu) - + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) # Denoising loop transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) From 29b0c81ea6228c6354c8dd418fad6776b714e559 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Mon, 24 Feb 2025 17:06:43 +0800 Subject: [PATCH 21/37] update shift_factor --- scripts/convert_cogview4_to_diffusers_megatron.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index b5cb7361fb71..1701a6283b8c 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -311,6 +311,7 @@ def main(args): "norm_num_groups": 32, "sample_size": 1024, "scaling_factor": 1.0, + "shift_factor": 0.0, "force_upcast": True, "use_quant_conv": False, "use_post_quant_conv": False, From 90830ed150fa7a19f0302bada45203d4880463bc Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 26 Feb 2025 16:59:38 +0800 Subject: [PATCH 22/37] remove the uncond --- .../train_control_cogview4.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 64034d6fb740..e97bfab75cd7 100644 --- 
a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -18,6 +18,7 @@ import logging import math import os +import random import shutil from contextlib import nullcontext from pathlib import Path @@ -1094,6 +1095,14 @@ def load_model_hook(models, input_dir): # TODO: Should a parameter be set here for passing? This is not present in Flux. crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) + + # this could be optimized by not having to do any text encoding and just + # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds` + if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: + # 这里,直接将 pooled_prompt_embeds 16个 pad token 提供给 prompt_embeds + prompt_embeds = pooled_prompt_embeds + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") # Predict. noise_pred_cond = cogview4_transformer( hidden_states=concatenated_noisy_model_input, @@ -1104,17 +1113,6 @@ def load_model_hook(models, input_dir): crop_coords=crops_coords_top_left, return_dict=False, )[0] - - noise_pred_uncond = cogview4_transformer( - hidden_states=concatenated_noisy_model_input, - encoder_hidden_states=pooled_prompt_embeds, - timestep=timesteps, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - return_dict=False, - )[0] - model_pred = noise_pred_uncond + (noise_pred_cond - noise_pred_uncond) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) @@ -1123,7 +1121,7 @@ def load_model_hook(models, input_dir): weighting = weighting.view(len(batch["captions"]), 1, 1, 1) loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1 + (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1 ) loss = loss.mean() accelerator.backward(loss) From 71f9235ef0683fb16c3c9fd8773616f434eb13b8 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 26 Feb 2025 18:11:02 +0800 Subject: [PATCH 23/37] add max length --- .../cogview4-control/train_control_cogview4.py | 17 ++++++++++++----- .../cogview4/pipeline_cogview4_control.py | 13 ++++++++----- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index e97bfab75cd7..a5e50acf8a9d 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -132,6 +132,8 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, control_image=validation_image, num_inference_steps=50, guidance_scale=args.guidance_scale, + max_sequence_length=max_sequence_length, # For downstream task training usage, training can be performed on a batch basis. + padding_type="max_length", generator=generator, height=args.resolution, width=args.resolution, @@ -267,6 +269,9 @@ def parse_args(input_args=None): " resolution" ), ) + parser.add_argument( + "--max_sequence_length", type=int, default=128, help="The maximum sequence length for the prompt." 
+ ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -1079,10 +1084,12 @@ def load_model_hook(models, input_dir): text_encoding_pipeline = text_encoding_pipeline.to("cuda") with torch.no_grad(): - ( - prompt_embeds, - pooled_prompt_embeds, - ) = text_encoding_pipeline.encode_prompt(captions, "") + # Since the batch will be padded, max_length should be used for padding. + prompt_embeds,pooled_prompt_embeds,= text_encoding_pipeline.encode_prompt( + captions, "", + max_sequence_length=args.max_sequence_length, + padding_type="max_length" + ) original_size = (args.resolution, args.resolution) original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) @@ -1099,7 +1106,7 @@ def load_model_hook(models, input_dir): # this could be optimized by not having to do any text encoding and just # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds` if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: - # 这里,直接将 pooled_prompt_embeds 16个 pad token 提供给 prompt_embeds + # Here, we directly pass 16 pad tokens from pooled_prompt_embeds to prompt_embeds. prompt_embeds = pooled_prompt_embeds if args.offload: text_encoding_pipeline = text_encoding_pipeline.to("cpu") diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 626c74a114ee..3626c198d35a 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -182,6 +182,7 @@ def _get_glm_embeds( prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 1024, + padding_type: str = "longest", device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -193,7 +194,7 @@ def _get_glm_embeds( text_inputs = self.tokenizer( prompt, - padding="longest", # not use max length + padding=padding_type, max_length=max_sequence_length, truncation=True, add_special_tokens=True, @@ -239,6 +240,7 @@ def encode_prompt( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 1024, + padding_type: str = "longest", ): r""" Encodes the prompt into text encoder hidden states. @@ -275,9 +277,8 @@ def encode_prompt( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - if prompt_embeds is None: - prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype) + prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, padding_type, device, dtype) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" @@ -296,7 +297,7 @@ def encode_prompt( ) negative_prompt_embeds = self._get_glm_embeds( - negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype + negative_prompt, num_images_per_prompt, max_sequence_length, "longest", device, dtype ) return prompt_embeds, negative_prompt_embeds @@ -450,6 +451,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 1024, + padding_type: str = "longest", # For downstream tasks, it can be modified to use max_length for implementation. ) -> Union[CogView4PipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. 
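The `padding_type` plumbing above exists because single-prompt inference can pad to the longest prompt actually present, while batched control training needs every caption padded to one fixed length so the text embeddings stack cleanly (and so the attention mask added later lines up). A sketch of the two tokenizer calls, loading the GLM tokenizer purely for illustration:

from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("THUDM/glm-4-9b-hf")
prompts = ["a red bicycle", "a photo of an astronaut riding a horse on mars"]

# Inference default: pad only up to the longest prompt in the batch.
longest = tokenizer(prompts, padding="longest", max_length=128,
                    truncation=True, add_special_tokens=True, return_tensors="pt")

# Training with --max_sequence_length: pad every caption to the same fixed length.
fixed = tokenizer(prompts, padding="max_length", max_length=128,
                  truncation=True, add_special_tokens=True, return_tensors="pt")

print(longest.input_ids.shape, fixed.input_ids.shape)   # (2, <=128) vs (2, 128)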
@@ -579,7 +581,8 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, - device=device, + padding_type=padding_type, + device=device ) # Prepare latents From 19d7d27c38cc861c9dc832ab61b143eab3b72873 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 27 Feb 2025 16:02:43 +0800 Subject: [PATCH 24/37] change convert and use GLMModel instead of GLMForCasualLM --- examples/cogview4-control/train_control_cogview4.py | 2 +- scripts/convert_cogview4_to_diffusers_megatron.py | 4 ++-- src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 2 +- src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index a5e50acf8a9d..c6fc2cb24f8c 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -660,7 +660,7 @@ def prepare_train_dataset(dataset, accelerator): [ transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + transforms.Lambda(lambda x: x * 2 - 1) ] ) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index 1701a6283b8c..19124231a7b6 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -25,7 +25,7 @@ import torch from tqdm import tqdm -from transformers import GlmForCausalLM, PreTrainedTokenizerFast +from transformers import GlmModel, PreTrainedTokenizerFast from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint @@ -326,7 +326,7 @@ def main(args): # Load the text encoder and tokenizer text_encoder_id = "THUDM/glm-4-9b-hf" tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id) - text_encoder = GlmForCausalLM.from_pretrained( + text_encoder = GlmModel.from_pretrained( text_encoder_id, cache_dir=args.text_encoder_cache_dir, torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32, diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 097d1b6aed41..f2c047fb22c9 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -215,7 +215,7 @@ def _get_glm_embeds( ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True + text_input_ids.to(self.text_encoder.device), output_hidden_states=True ).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 3626c198d35a..ea956a8b2fbc 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -219,7 +219,7 @@ def _get_glm_embeds( ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.model.device), 
output_hidden_states=True + text_input_ids.to(self.text_encoder.device), output_hidden_states=True ).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) From 2f74c4e47484984118214c06ce45e6c83cdceabc Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 27 Feb 2025 16:42:15 +0800 Subject: [PATCH 25/37] fix --- examples/cogview4-control/train_control_cogview4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index c6fc2cb24f8c..deb0b81bd60a 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -132,7 +132,7 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, control_image=validation_image, num_inference_steps=50, guidance_scale=args.guidance_scale, - max_sequence_length=max_sequence_length, # For downstream task training usage, training can be performed on a batch basis. + max_sequence_length=args.max_sequence_length, # For downstream task training usage, training can be performed on a batch basis. padding_type="max_length", generator=generator, height=args.resolution, From 264060eeb390baabe458bf90bb0677df01c51f3e Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 28 Feb 2025 08:40:13 +0000 Subject: [PATCH 26/37] [cogview4] Add attention mask support to transformer model --- .../transformers/transformer_cogview4.py | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 5248ced02c43..6717f3b95c69 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -17,16 +17,17 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ...loaders import PeftAdapterMixin + from ...configuration_utils import ConfigMixin, register_to_config from ...models.attention import FeedForward from ...models.attention_processor import Attention from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous from ...utils import logging -from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput +from ...loaders import PeftAdapterMixin +from ..cache_utils import CacheMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -123,10 +124,11 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) # 1. QKV projections @@ -156,6 +158,17 @@ def __call__( ) # 4. 
Attention + if attention_mask is not None: + # construct attention_mask for concated sequence + text_attention_mask = attention_mask.float().to(query.device) + attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) + attention_mask[:, :text_seq_length] = text_attention_mask + attention_mask = attention_mask.unsqueeze(2) + attention_mask_matrix = attention_mask @ attention_mask.mT + attention_mask_matrix = attention_mask_matrix == 1 + attention_mask_matrix = attention_mask_matrix.unsqueeze(1) + attention_mask = attention_mask_matrix + hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -203,6 +216,8 @@ def forward( encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: # 1. Timestep conditioning ( @@ -223,6 +238,8 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + **kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) @@ -233,8 +250,8 @@ def forward( 1 + c_scale_mlp.unsqueeze(1) ) + c_shift_mlp.unsqueeze(1) - ff_output = self.ff(norm_hidden_states) - ff_output_context = self.ff(norm_encoder_hidden_states) + ff_output = self.ff(norm_hidden_states, **kwargs) + ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs) hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) @@ -381,6 +398,8 @@ def forward( target_size: torch.Tensor, crop_coords: torch.Tensor, return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: batch_size, num_channels, height, width = hidden_states.shape @@ -391,6 +410,7 @@ def forward( p = self.config.patch_size post_patch_height = height // p post_patch_width = width // p + hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states) temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) @@ -400,11 +420,11 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, temb, image_rotary_emb + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs ) else: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, image_rotary_emb + hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs ) # 4. 
Output norm & projection From 9a10cebb5c81957773f4887822d3a782fcc67fa7 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 4 Mar 2025 10:26:45 +0000 Subject: [PATCH 27/37] [fix] Add attention mask for padded token --- .../cogview4-control/train_control_cogview4.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index deb0b81bd60a..2c8c991d9369 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -40,9 +40,9 @@ import diffusers from diffusers import ( AutoencoderKL, - FlowMatchEulerDiscreteScheduler, CogView4ControlPipeline, CogView4Transformer2DModel, + FlowMatchEulerDiscreteScheduler, ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( @@ -977,6 +977,7 @@ def load_model_hook(models, input_dir): text_encoding_pipeline = CogView4ControlPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype ) + tokenizer = text_encoding_pipeline.tokenizer # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: @@ -1043,6 +1044,16 @@ def load_model_hook(models, input_dir): with accelerator.accumulate(cogview4_transformer): # Convert images to latent space # vae encode + prompts = batch["captions"] + attention_mask = tokenizer( + prompts, + padding="longest", # not use max length + max_length=args.max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).attention_mask.float() + pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype) control_latents = encode_images( batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype @@ -1119,6 +1130,7 @@ def load_model_hook(models, input_dir): target_size=target_size, crop_coords=crops_coords_top_left, return_dict=False, + attention_mask=attention_mask, )[0] # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss From 692e5cc3c0bf0df8edc1c46200fc8d06c9e0949f Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 5 Mar 2025 16:16:54 +0800 Subject: [PATCH 28/37] update --- .../train_control_cogview4.py | 20 +++++++---------- .../transformers/transformer_cogview4.py | 22 ++++++++----------- .../cogview4/pipeline_cogview4_control.py | 22 +++++++------------ 3 files changed, 25 insertions(+), 39 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 2c8c991d9369..9649fa88e3ba 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -132,7 +132,7 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, control_image=validation_image, num_inference_steps=50, guidance_scale=args.guidance_scale, - max_sequence_length=args.max_sequence_length, # For downstream task training usage, training can be performed on a batch basis. + max_sequence_length=args.max_sequence_length, # For downstream task training usage, training can be performed on a batch basis. 
padding_type="max_length", generator=generator, height=args.resolution, @@ -660,7 +660,7 @@ def prepare_train_dataset(dataset, accelerator): [ transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), - transforms.Lambda(lambda x: x * 2 - 1) + transforms.Lambda(lambda x: x * 2 - 1), ] ) @@ -1074,7 +1074,6 @@ def load_model_hook(models, input_dir): ) # Add noise according for cogview4 - # FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one. indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device) @@ -1095,12 +1094,10 @@ def load_model_hook(models, input_dir): text_encoding_pipeline = text_encoding_pipeline.to("cuda") with torch.no_grad(): - # Since the batch will be padded, max_length should be used for padding. - prompt_embeds,pooled_prompt_embeds,= text_encoding_pipeline.encode_prompt( - captions, "", - max_sequence_length=args.max_sequence_length, - padding_type="max_length" - ) + ( + prompt_embeds, + pooled_prompt_embeds, + ) = text_encoding_pipeline.encode_prompt(captions, "") original_size = (args.resolution, args.resolution) original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) @@ -1109,8 +1106,6 @@ def load_model_hook(models, input_dir): target_size = target_size.repeat(len(batch["captions"]), 1) original_size = original_size.repeat(len(batch["captions"]), 1) - - # TODO: Should a parameter be set here for passing? This is not present in Flux. crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) @@ -1140,7 +1135,8 @@ def load_model_hook(models, input_dir): weighting = weighting.view(len(batch["captions"]), 1, 1, 1) loss = torch.mean( - (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1 + (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, ) loss = loss.mean() accelerator.backward(loss) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 6717f3b95c69..a79589973015 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -157,21 +157,17 @@ def __call__( key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) - # 4. Attention + # 4. 
Attention and Attention Mask if attention_mask is not None: - # construct attention_mask for concated sequence text_attention_mask = attention_mask.float().to(query.device) - attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) - attention_mask[:, :text_seq_length] = text_attention_mask - attention_mask = attention_mask.unsqueeze(2) - attention_mask_matrix = attention_mask @ attention_mask.mT - attention_mask_matrix = attention_mask_matrix == 1 - attention_mask_matrix = attention_mask_matrix.unsqueeze(1) - attention_mask = attention_mask_matrix - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + actual_text_seq_length = text_attention_mask.size(1) + new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device) + new_attention_mask[:, :actual_text_seq_length] = text_attention_mask + new_attention_mask = new_attention_mask.unsqueeze(2) + attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2) + attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype) + + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index ea956a8b2fbc..59e55d74210d 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -144,13 +144,11 @@ class CogView4ControlPipeline(DiffusionPipeline): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`T5EncoderModel`]): - Frozen text-encoder. CogView4 uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the - [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. - tokenizer (`T5Tokenizer`): + text_encoder ([`GLMModel`]): + Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). + tokenizer (`PreTrainedTokenizer`): Tokenizer of class - [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). transformer ([`CogView4Transformer2DModel`]): A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): @@ -182,7 +180,6 @@ def _get_glm_embeds( prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 1024, - padding_type: str = "longest", device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -194,7 +191,7 @@ def _get_glm_embeds( text_inputs = self.tokenizer( prompt, - padding=padding_type, + padding="longest", # not use max length max_length=max_sequence_length, truncation=True, add_special_tokens=True, @@ -240,7 +237,6 @@ def encode_prompt( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 1024, - padding_type: str = "longest", ): r""" Encodes the prompt into text encoder hidden states. 
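The padded-token handling introduced in the patches above turns the tokenizer's 1-D padding mask into a 2-D mask over the concatenated text+image sequence: an outer product of the per-token mask with itself, so padded text positions neither attend nor are attended to, while image positions stay fully visible. A minimal sketch of that construction follows; the helper name and shapes are illustrative assumptions, not the exact diffusers code.

```python
import torch

def build_joint_attention_mask(text_mask: torch.Tensor, image_seq_length: int) -> torch.Tensor:
    # text_mask: (batch, text_seq_len), 1 for real tokens, 0 for padding
    batch_size, text_seq_length = text_mask.shape
    joint = torch.ones(batch_size, text_seq_length + image_seq_length)
    joint[:, :text_seq_length] = text_mask.float()   # image tokens are always visible
    joint = joint.unsqueeze(2)                       # (batch, seq, 1)
    matrix = joint @ joint.transpose(1, 2)           # outer product -> (batch, seq, seq)
    return (matrix > 0).unsqueeze(1)                 # (batch, 1, seq, seq) bool mask for SDPA

mask = build_joint_attention_mask(torch.tensor([[1, 1, 1, 0, 0]]), image_seq_length=4)
# torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) would then
# skip the two padded text positions for every query in the joint sequence.
print(mask.shape)  # torch.Size([1, 1, 9, 9])
```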
@@ -278,7 +274,7 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, padding_type, device, dtype) + prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" @@ -297,7 +293,7 @@ def encode_prompt( ) negative_prompt_embeds = self._get_glm_embeds( - negative_prompt, num_images_per_prompt, max_sequence_length, "longest", device, dtype + negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype ) return prompt_embeds, negative_prompt_embeds @@ -451,7 +447,6 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 1024, - padding_type: str = "longest", # For downstream tasks, it can be modified to use max_length for implementation. ) -> Union[CogView4PipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -581,8 +576,7 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, - padding_type=padding_type, - device=device + device=device, ) # Prepare latents From fc3830c922d3800215d54199d311383566f26976 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 6 Mar 2025 20:02:43 +0800 Subject: [PATCH 29/37] remove padding type --- examples/cogview4-control/train_control_cogview4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 9649fa88e3ba..1f0d6ef4ed58 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -132,8 +132,6 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, control_image=validation_image, num_inference_steps=50, guidance_scale=args.guidance_scale, - max_sequence_length=args.max_sequence_length, # For downstream task training usage, training can be performed on a batch basis. 
- padding_type="max_length", generator=generator, height=args.resolution, width=args.resolution, From 98a2417ce9c08d401e6497b5d2124bdfdc70bfde Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 6 Mar 2025 20:03:09 +0800 Subject: [PATCH 30/37] Update train_control_cogview4.py --- examples/cogview4-control/train_control_cogview4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 1f0d6ef4ed58..506ca0225bf7 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -132,6 +132,7 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, control_image=validation_image, num_inference_steps=50, guidance_scale=args.guidance_scale, + max_sequence_length=args.max_sequence_length, generator=generator, height=args.resolution, width=args.resolution, From c774f4536be137e0d48899680835c4467109b4e6 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 12 Mar 2025 16:48:25 +0800 Subject: [PATCH 31/37] resolve conflicts with #10981 --- .../transformers/transformer_cogview4.py | 78 ++++++++++++------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index a79589973015..43a20dec20a2 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import FeedForward -from ...models.attention_processor import Attention -from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous -from ...utils import logging +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_processor import Attention from ..embeddings import CogView3CombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput -from ...loaders import PeftAdapterMixin +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous from ..cache_utils import CacheMixin @@ -124,7 +124,7 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape @@ -157,7 +157,7 @@ def __call__( key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) - # 4. Attention and Attention Mask + # 4. 
Attention if attention_mask is not None: text_attention_mask = attention_mask.float().to(query.device) actual_text_seq_length = text_attention_mask.size(1) @@ -167,7 +167,9 @@ def __call__( attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2) attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype) - hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -246,8 +248,8 @@ def forward( 1 + c_scale_mlp.unsqueeze(1) ) + c_shift_mlp.unsqueeze(1) - ff_output = self.ff(norm_hidden_states, **kwargs) - ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs) + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) @@ -258,30 +260,34 @@ class CogView4RotaryPosEmbed(nn.Module): def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None: super().__init__() + self.dim = dim self.patch_size = patch_size self.rope_axes_dim = rope_axes_dim - - dim_h, dim_w = dim // 2, dim // 2 - h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)) - w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) - h_seq = torch.arange(self.rope_axes_dim[0]) - w_seq = torch.arange(self.rope_axes_dim[1]) - self.freqs_h = torch.outer(h_seq, h_inv_freq) - self.freqs_w = torch.outer(w_seq, w_inv_freq) + self.theta = theta def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, num_channels, height, width = hidden_states.shape height, width = height // self.patch_size, width // self.patch_size - h_idx = torch.arange(height) - w_idx = torch.arange(width) + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) + h_seq = torch.arange(self.rope_axes_dim[0]) + w_seq = torch.arange(self.rope_axes_dim[1]) + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + h_idx = torch.arange(height, device=freqs_h.device) + w_idx = torch.arange(width, device=freqs_w.device) inner_h_idx = h_idx * self.rope_axes_dim[0] // height inner_w_idx = w_idx * self.rope_axes_dim[1] // width - self.freqs_h = self.freqs_h.to(hidden_states.device) - self.freqs_w = self.freqs_w.to(hidden_states.device) - freqs_h = self.freqs_h[inner_h_idx] - freqs_w = self.freqs_w[inner_w_idx] + freqs_h = freqs_h[inner_h_idx] + freqs_w = freqs_w[inner_w_idx] # Create position matrices for height and width # [height, 1, dim//4] and [1, width, dim//4] @@ -393,10 +399,26 @@ def forward( original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: + if 
attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_channels, height, width = hidden_states.shape # 1. RoPE @@ -431,6 +453,10 @@ def forward( hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) From 8abca19db1c4ac4938dfa5349e2e8a9f70e699d5 Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Wed, 12 Mar 2025 18:10:05 +0800 Subject: [PATCH 32/37] add control convert --- .../convert_cogview4_to_diffusers_megatron.py | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index 19124231a7b6..df5e540ec9ef 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -27,7 +27,13 @@ from tqdm import tqdm from transformers import GlmModel, PreTrainedTokenizerFast -from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers import ( + AutoencoderKL, + CogView4Pipeline, + CogView4ControlPipeline, + CogView4Transformer2DModel, + FlowMatchEulerDiscreteScheduler, +) from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint @@ -112,6 +118,12 @@ default=128, help="Maximum size for positional embeddings.", ) +parser.add_argument( + "--control", + action="store_true", + default=False, + help="Whether to use control model.", +) args = parser.parse_args() @@ -156,7 +168,9 @@ def convert_megatron_transformer_checkpoint_to_diffusers( new_state_dict = {} # Patch Embedding - new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64) + new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape( + hidden_size, 128 if args.control else 64, 64 + ) new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"] new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"] new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"] @@ -340,13 +354,22 @@ def main(args): ) # Create the pipeline - pipe = CogView4Pipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - vae=vae, - transformer=transformer, - scheduler=scheduler, - ) + if args.control: + pipe = CogView4ControlPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + else: + pipe = CogView4Pipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) # Save the converted pipeline pipe.save_pretrained( From 347dd17e8ac6d1d28c8862ff3d91226f131c21af Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 13 Mar 2025 15:09:34 
+0800 Subject: [PATCH 33/37] use control format --- .../convert_cogview4_to_diffusers_megatron.py | 8 ++++---- .../cogview4/pipeline_cogview4_control.py | 18 +++++++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index df5e540ec9ef..7c41301ed663 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -162,14 +162,14 @@ def convert_megatron_transformer_checkpoint_to_diffusers( Returns: dict: The converted state dictionary compatible with Diffusers. """ - ckpt = torch.load(ckpt_path, map_location="cpu") + ckpt = torch.load(ckpt_path, map_location="cpu",weights_only=False) mega = ckpt["model"] new_state_dict = {} # Patch Embedding new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape( - hidden_size, 128 if args.control else 64, 64 + hidden_size, 128 if args.control else 64 ) new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"] new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"] @@ -260,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config): Returns: dict: The converted VAE state dictionary compatible with Diffusers. """ - original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + original_state_dict = torch.load(ckpt_path, map_location="cpu",weights_only=False)["state_dict"] return convert_ldm_vae_checkpoint(original_state_dict, vae_config) @@ -294,7 +294,7 @@ def main(args): ) transformer = CogView4Transformer2DModel( patch_size=2, - in_channels=16, + in_channels=32 if args.control else 16, num_layers=args.num_layers, attention_head_dim=args.attention_head_dim, num_attention_heads=args.num_heads, diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 59e55d74210d..3b7539c4a88e 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -46,15 +46,18 @@ >>> from diffusers import CogView4Pipeline >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - - >>> prompt = "A photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt).images[0] - >>> image.save("output.png") + >>> control_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ... ) + >>> prompt = "A bird in space" + >>> image = pipe( + ... prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5) + ... 
).images[0] + >>> image.save("cogview4-control.png") ``` """ - +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -175,6 +178,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds def _get_glm_embeds( self, prompt: Union[str, List[str]] = None, @@ -341,7 +345,7 @@ def prepare_image( # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt - image = image.repeat_interleave(repeat_by, dim=0) + image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by) image = image.to(device=device, dtype=dtype) From 775bb8ce1c8da25dd388c2a8be9998b7e0e3614e Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 13 Mar 2025 15:16:53 +0800 Subject: [PATCH 34/37] fix --- .../pipelines/cogview4/pipeline_cogview4.py | 16 ++++++++++---- .../cogview4/pipeline_cogview4_control.py | 21 +++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index a60fcc4ffc8b..c27a1a19774d 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -389,14 +389,18 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps - @property - def interrupt(self): - return self._interrupt - @property def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -533,6 +537,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # Default call parameters @@ -610,6 +615,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -661,6 +667,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False, generator=generator)[0] diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 3b7539c4a88e..644dc169349c 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -14,7 +14,7 @@ # limitations under the License. 
import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union, Any import numpy as np import torch @@ -43,7 +43,7 @@ Examples: ```python >>> import torch - >>> from diffusers import CogView4Pipeline + >>> from diffusers import CogView4ControlPipeline >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) >>> control_image = load_image( @@ -420,6 +420,14 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -446,6 +454,7 @@ def __call__( crops_coords_top_left: Tuple[int, int] = (0, 0), output_type: str = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -559,6 +568,8 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # Default call parameters @@ -652,6 +663,8 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue + + self._current_timestep = t latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -664,6 +677,7 @@ def __call__( original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -676,6 +690,7 @@ def __call__( original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -700,6 +715,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False, generator=generator)[0] From 985baa92df61d0fe939465984b7b74456bef7ecf Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Thu, 13 Mar 2025 15:20:24 +0800 Subject: [PATCH 35/37] add missing import --- src/diffusers/pipelines/cogview4/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py index 531cea7d7c66..6a365e17fee7 100644 --- a/src/diffusers/pipelines/cogview4/__init__.py +++ b/src/diffusers/pipelines/cogview4/__init__.py @@ -32,6 +32,7 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_cogview4 import CogView4Pipeline + from .pipeline_cogview4_control import CogView4ControlPipeline else: import sys From 88abb398cc5e42bcc008bca602b3bcfb946f142f Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Sat, 15 Mar 2025 16:43:06 +0800 Subject: [PATCH 36/37] update with cogview4 formate --- scripts/convert_cogview4_to_diffusers_megatron.py | 2 +- src/diffusers/__init__.py | 2 +- src/diffusers/models/transformers/transformer_cogview4.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/auto_pipeline.py | 2 +- src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py | 4 ++-- 6 files changed, 7 
insertions(+), 7 deletions(-) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index 7c41301ed663..bef3dcbbd603 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -29,8 +29,8 @@ from diffusers import ( AutoencoderKL, - CogView4Pipeline, CogView4ControlPipeline, + CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler, ) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d1381a29ff81..0b4794d5cd35 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -890,8 +890,8 @@ CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, - CogView4Pipeline, CogView4ControlPipeline, + CogView4Pipeline, ConsisIDPipeline, CycleDiffusionPipeline, EasyAnimateControlPipeline, diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 43a20dec20a2..41c4cbbf97c7 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -23,11 +23,11 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous -from ..cache_utils import CacheMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a561bad2c495..466b8b613b9d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -511,7 +511,7 @@ CogVideoXVideoToVideoPipeline, ) from .cogview3 import CogView3PlusPipeline - from .cogview4 import CogView4Pipeline, CogView4ControlPipeline + from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .consisid import ConsisIDPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 4a03e4467af3..6a5f6098b6fb 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -22,7 +22,7 @@ from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline -from .cogview4 import CogView4Pipeline, CogView4ControlPipeline +from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 644dc169349c..ef9979fee830 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -14,14 +14,14 @@ # limitations under the License. 
import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from transformers import AutoTokenizer, GlmModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...image_processor import PipelineImageInput, VaeImageProcessor from ...models import AutoencoderKL, CogView4Transformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler From 3e3387ec5019bc90c82e09897e9e3893f37ae007 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 15 Mar 2025 15:35:34 +0100 Subject: [PATCH 37/37] make style --- .../convert_cogview4_to_diffusers_megatron.py | 4 +- src/diffusers/__init__.py | 2 +- .../cogview4/pipeline_cogview4_control.py | 45 ++++++++----------- .../dummy_torch_and_transformers_objects.py | 6 ++- 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py index bef3dcbbd603..8faeccb13888 100644 --- a/scripts/convert_cogview4_to_diffusers_megatron.py +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -162,7 +162,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers( Returns: dict: The converted state dictionary compatible with Diffusers. """ - ckpt = torch.load(ckpt_path, map_location="cpu",weights_only=False) + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) mega = ckpt["model"] new_state_dict = {} @@ -260,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config): Returns: dict: The converted VAE state dictionary compatible with Diffusers. """ - original_state_dict = torch.load(ckpt_path, map_location="cpu",weights_only=False)["state_dict"] + original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"] return convert_ldm_vae_checkpoint(original_state_dict, vae_config) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0b4794d5cd35..65e9bb695e6e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -345,8 +345,8 @@ "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", - "CogView4Pipeline", "CogView4ControlPipeline", + "CogView4Pipeline", "ConsisIDPipeline", "CycleDiffusionPipeline", "EasyAnimateControlPipeline", diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index ef9979fee830..b22705ed05c9 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -50,13 +50,12 @@ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" ... ) >>> prompt = "A bird in space" - >>> image = pipe( - ... prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5) - ... 
).images[0] + >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0] >>> image.save("cogview4-control.png") ``` """ + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift def calculate_shift( image_seq_len, @@ -101,19 +100,10 @@ def retrieve_timesteps( `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if timesteps is not None and sigmas is not None: - if not accepts_timesteps and not accepts_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep or sigma schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif timesteps is not None and sigmas is None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -122,8 +112,9 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) - elif timesteps is None and sigmas is not None: - if not accepts_sigmas: + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
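The retrieve_timesteps cleanup in the hunk above makes custom `timesteps` and `sigmas` mutually exclusive and probes the scheduler's `set_timesteps` signature before forwarding either one. A small sketch of that signature check, using a hypothetical stand-in scheduler rather than a real diffusers class:

```python
import inspect

class DummyScheduler:
    # Stand-in whose set_timesteps accepts `sigmas` but not `timesteps`.
    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None):
        self.sigmas = sigmas if sigmas is not None else [1.0, 0.5, 0.0]
        self.timesteps = list(range(len(self.sigmas)))

scheduler = DummyScheduler()
params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if "sigmas" not in params:
    raise ValueError(f"{scheduler.__class__} does not support a custom sigmas schedule.")
scheduler.set_timesteps(sigmas=[1.0, 0.75, 0.5, 0.25])
num_inference_steps = len(scheduler.timesteps)
print(num_inference_steps)  # 4
```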
@@ -182,7 +173,6 @@ def __init__( def _get_glm_embeds( self, prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, max_sequence_length: int = 1024, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, @@ -191,7 +181,6 @@ def _get_glm_embeds( dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) text_inputs = self.tokenizer( prompt, @@ -224,9 +213,6 @@ def _get_glm_embeds( ).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt @@ -277,8 +263,13 @@ def encode_prompt( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + if prompt_embeds is None: - prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype) + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" @@ -296,9 +287,11 @@ def encode_prompt( " the batch size of `prompt`." ) - negative_prompt_embeds = self._get_glm_embeds( - negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype - ) + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds, negative_prompt_embeds diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4af260d27391..ae606c3709e5 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -362,7 +362,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class CogView4Pipeline(metaclass=DummyObject): +class CogView4ControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -376,7 +376,8 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class CogView4ControlPipeline(metaclass=DummyObject): + +class CogView4Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -390,6 +391,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + class ConsisIDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]
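One detail worth spelling out from the final patch: `_get_glm_embeds` now returns one embedding per prompt, and `encode_prompt` tiles it for `num_images_per_prompt` afterwards. The reshaping is a repeat along the sequence axis followed by a view back to a larger batch; a self-contained sketch with illustrative shapes (the sizes below are assumptions, not CogView4's actual dimensions):

```python
import torch

batch_size, seq_len, dim = 2, 16, 4096   # illustrative sizes
num_images_per_prompt = 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)   # one embedding per prompt

# Tile each prompt's embedding so every generated image gets its own copy.
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)                    # (2, 48, 4096)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)  # (6, 16, 4096)
print(prompt_embeds.shape)  # torch.Size([6, 16, 4096])
```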