diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 04c787ee3e84..a91180b11825 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -41,7 +41,9 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (
+            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
+        )
 
         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:
@@ -115,34 +117,24 @@ def __init__(
             dilation=dilation,
         )
 
-        self.conv_cache = None
-
-    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def fake_context_parallel_forward(
+        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         kernel_size = self.time_kernel_size
         if kernel_size > 1:
-            cached_inputs = (
-                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
-            )
+            cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
             inputs = torch.cat(cached_inputs + [inputs], dim=2)
         return inputs
 
-    def _clear_fake_context_parallel_cache(self):
-        del self.conv_cache
-        self.conv_cache = None
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = self.fake_context_parallel_forward(inputs)
-
-        self._clear_fake_context_parallel_cache()
-        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
-        # hundred megabytes and so let's not do it for now
-        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs, conv_cache)
+        conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
 
         padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
         inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
-        return output
+        return output, conv_cache
 
 
 class CogVideoXSpatialNorm3D(nn.Module):
@@ -172,7 +164,12 @@ def __init__(
         self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
         self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
 
-    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         if f.shape[2] > 1 and f.shape[2] % 2 == 1:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
@@ -183,9 +180,12 @@ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
         else:
             zq = F.interpolate(zq, size=f.shape[-3:])
 
+        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
         norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        new_f = norm_f * conv_y + conv_b
+        return new_f, new_conv_cache
 
 
 class CogVideoXResnetBlock3D(nn.Module):
@@ -236,6 +236,7 @@ def __init__(
         self.out_channels = out_channels
         self.nonlinearity = get_activation(non_linearity)
         self.use_conv_shortcut = conv_shortcut
+        self.spatial_norm_dim = spatial_norm_dim
 
         if spatial_norm_dim is None:
             self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
@@ -279,34 +280,43 @@ def forward(
         inputs: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
         hidden_states = inputs
 
         if zq is not None:
-            hidden_states = self.norm1(hidden_states, zq)
+            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
         else:
             hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-        hidden_states = self.conv1(hidden_states)
+        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
 
         if temb is not None:
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
 
         if zq is not None:
-            hidden_states = self.norm2(hidden_states, zq)
+            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
         else:
             hidden_states = self.norm2(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
+        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
 
         if self.in_channels != self.out_channels:
-            inputs = self.conv_shortcut(inputs)
+            if self.use_conv_shortcut:
+                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
+                    inputs, conv_cache=conv_cache.get("conv_shortcut")
+                )
+            else:
+                inputs = self.conv_shortcut(inputs)
 
         hidden_states = hidden_states + inputs
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDownBlock3D(nn.Module):
@@ -392,8 +402,16 @@ def forward(
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXDownBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -402,17 +420,23 @@ def create_forward(*inputs):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXMidBlock3D(nn.Module):
@@ -480,8 +504,16 @@ def forward(
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        for resnet in self.resnets:
+        r"""Forward method of the `CogVideoXMidBlock3D` class."""
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -490,13 +522,15 @@ def create_forward(*inputs):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXUpBlock3D(nn.Module):
@@ -584,9 +618,16 @@ def forward(
         hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         zq: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
         r"""Forward method of the `CogVideoXUpBlock3D` class."""
-        for resnet in self.resnets:
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        for i, resnet in enumerate(self.resnets):
+            conv_cache_key = f"resnet_{i}"
+
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -595,17 +636,23 @@ def create_forward(*inputs):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    zq,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
             else:
-                hidden_states = resnet(hidden_states, temb, zq)
+                hidden_states, new_conv_cache[conv_cache_key] = resnet(
+                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
 
-        return hidden_states
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXEncoder3D(nn.Module):
@@ -705,9 +752,18 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXEncoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -718,28 +774,44 @@ def custom_forward(*inputs):
                 return custom_forward
 
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, temb, None
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block),
+                    hidden_states,
+                    temb,
+                    None,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
 
             # 2. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, None
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                None,
+                conv_cache=conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
-            for down_block in self.down_blocks:
-                hidden_states = down_block(hidden_states, temb, None)
+            for i, down_block in enumerate(self.down_blocks):
+                conv_cache_key = f"down_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = down_block(
+                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
             # 2. Mid
-            hidden_states = self.mid_block(hidden_states, temb, None)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
+            )
 
         # 3. Post-process
         hidden_states = self.norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class CogVideoXDecoder3D(nn.Module):
@@ -846,9 +918,18 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sample: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `CogVideoXDecoder3D` class."""
-        hidden_states = self.conv_in(sample)
+
+        new_conv_cache = {}
+        conv_cache = conv_cache or {}
+
+        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
 
         if self.training and self.gradient_checkpointing:
 
@@ -859,28 +940,45 @@ def custom_forward(*inputs):
                 return custom_forward
 
             # 1. Mid
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                hidden_states,
+                temb,
+                sample,
+                conv_cache=conv_cache.get("mid_block"),
            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, temb, sample
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    hidden_states,
+                    temb,
+                    sample,
+                    conv_cache=conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
-            hidden_states = self.mid_block(hidden_states, temb, sample)
+            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
+                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
+            )
 
             # 2. Up
-            for up_block in self.up_blocks:
-                hidden_states = up_block(hidden_states, temb, sample)
+            for i, up_block in enumerate(self.up_blocks):
+                conv_cache_key = f"up_block_{i}"
+                hidden_states, new_conv_cache[conv_cache_key] = up_block(
+                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
+                )
 
         # 3. Post-process
-        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
+            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
+        )
         hidden_states = self.conv_act(hidden_states)
-        hidden_states = self.conv_out(hidden_states)
-        return hidden_states
+        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
+
+        return hidden_states, new_conv_cache
 
 
 class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@@ -1019,12 +1117,6 @@ def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
             module.gradient_checkpointing = value
 
-    def _clear_fake_context_parallel_cache(self):
-        for name, module in self.named_modules():
-            if isinstance(module, CogVideoXCausalConv3d):
-                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
-                module._clear_fake_context_parallel_cache()
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -1091,20 +1183,20 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
         frame_batch_size = self.num_sample_frames_batch_size
         # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
         num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+        conv_cache = None
         enc = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
             x_intermediate = x[:, :, start_frame:end_frame]
-            x_intermediate = self.encoder(x_intermediate)
+            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
             if self.quant_conv is not None:
                 x_intermediate = self.quant_conv(x_intermediate)
             enc.append(x_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         enc = torch.cat(enc, dim=2)
-
         return enc
 
     @apply_forward_hook
@@ -1143,7 +1235,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
 
         frame_batch_size = self.num_latent_frames_batch_size
         num_batches = num_frames // frame_batch_size
+        conv_cache = None
         dec = []
+
         for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
@@ -1151,10 +1245,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
             z_intermediate = z[:, :, start_frame:end_frame]
             if self.post_quant_conv is not None:
                 z_intermediate = self.post_quant_conv(z_intermediate)
-            z_intermediate = self.decoder(z_intermediate)
+            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
             dec.append(z_intermediate)
 
-        self._clear_fake_context_parallel_cache()
         dec = torch.cat(dec, dim=2)
 
         if not return_dict:
@@ -1238,7 +1331,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
             for j in range(0, width, overlap_width):
                 # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
                 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1250,11 +1345,11 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                         i : i + self.tile_sample_min_height,
                         j : j + self.tile_sample_min_width,
                     ]
-                    tile = self.encoder(tile)
+                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                     if self.quant_conv is not None:
                         tile = self.quant_conv(tile)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
@@ -1315,7 +1410,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
             row = []
            for j in range(0, width, overlap_width):
                 num_batches = num_frames // frame_batch_size
+                conv_cache = None
                 time = []
+
                 for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
@@ -1329,9 +1426,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
                     ]
                     if self.post_quant_conv is not None:
                         tile = self.post_quant_conv(tile)
-                    tile = self.decoder(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                     time.append(tile)
-                self._clear_fake_context_parallel_cache()
+
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
 
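For reference, here is a minimal, self-contained sketch (not part of the patch) of the cache-threading pattern this diff introduces: a causal temporal conv accepts an optional `conv_cache` and returns the cache for the next chunk, so state is owned by the caller rather than stored on the module. The `ToyCausalConv3d` class and the chunked loop below are illustrative stand-ins, not the diffusers implementation; they assume only the behavior visible in the diff (forward takes `conv_cache` and returns `(output, new_conv_cache)`).

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyCausalConv3d(nn.Module):
    """Illustrative stand-in for a causal temporal conv with an explicit cache."""

    def __init__(self, channels: int, time_kernel_size: int = 3):
        super().__init__()
        self.time_kernel_size = time_kernel_size
        self.conv = nn.Conv3d(channels, channels, kernel_size=(time_kernel_size, 1, 1))

    def forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Prepend either the trailing frames cached from the previous chunk or a
        # repeat of the first frame (the same idea as fake_context_parallel_forward).
        if conv_cache is None:
            pad = [inputs[:, :, :1]] * (self.time_kernel_size - 1)
        else:
            pad = [conv_cache]
        inputs = torch.cat(pad + [inputs], dim=2)
        # The last (time_kernel_size - 1) frames become the cache for the next chunk.
        new_conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
        return self.conv(inputs), new_conv_cache


# The caller threads the cache across temporal chunks, mirroring how `_encode` /
# `_decode` now pass `conv_cache` back into `self.encoder` / `self.decoder`.
layer = ToyCausalConv3d(channels=4)
video = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, frames, height, width)
conv_cache = None
outputs = []
for chunk in video.split(2, dim=2):
    out, conv_cache = layer(chunk, conv_cache=conv_cache)
    outputs.append(out)
result = torch.cat(outputs, dim=2)  # same temporal length as processing all 8 frames at once
```

Because the trailing frames are handed back to the caller, there is no hidden module state to reset between samples, which is why `_clear_fake_context_parallel_cache` and the stateful `self.conv_cache` attribute are removed in the diff above.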