diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index b45cb2a7950d..198c3ed18070 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -741,10 +741,14 @@ def prepare_attention_mask(
 
         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )
 
         return attention_mask
 
@@ -3704,8 +3708,10 @@ def __call__(
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )
 
         if attn.norm_q is not None:
             query = attn.norm_q(query)
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
index 1e6a26dddca8..9146aa5c7c6c 100644
--- a/src/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -190,7 +190,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         x = F.pixel_shuffle(x, self.factor)
 
         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
@@ -361,7 +361,9 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
index f79aabe91dd3..a76277366c09 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -103,7 +103,7 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
         if self.down_sample:
             identity = hidden_states[:, :, ::2]
         elif self.up_sample:
-            identity = hidden_states.repeat_interleave(2, dim=2)
+            identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
         else:
             identity = hidden_states
 
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index cd3eff73ed64..d69ec6252b00 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -426,7 +426,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         w = w.repeat(num_channels)[None, :, None, None, None]  # [1, num_channels * num_freqs, 1, 1, 1]
 
         # Interleaved repeat of input channels to match w
-        h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
+        h = inputs.repeat_interleave(
+            num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
+        )  # [B, C * num_freqs, T, H, W]
         # Scale channels by frequency.
         h = w * h
 
diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
index 4edc91cacaa7..25348ce606d6 100644
--- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -687,7 +687,7 @@ def forward(
         t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 04a0b273f1fa..6dce88826ba0 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
 
     # 3. Concat
    pos_embed_spatial = pos_embed_spatial[None, :, :]
-    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+    )  # [T, H*W, D // 4 * 3]
 
     pos_embed_temporal = pos_embed_temporal[:, None, :]
     pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -1154,8 +1156,8 @@ def get_1d_rotary_pos_embed(
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 4fe1d99cb6ee..4b359021f29d 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -227,13 +227,17 @@ def forward(
         # Prepare text embeddings for spatial block
        # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
 
         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])
 
         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ def forward(
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
 
-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py
index fdb67384ff5e..24d4e4d3d76f 100644
--- a/src/diffusers/models/transformers/prior_transformer.py
+++ b/src/diffusers/models/transformers/prior_transformer.py
@@ -353,7 +353,11 @@ def forward(
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )
 
         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index 845d93b9db09..a148cf6cbe06 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -638,8 +638,10 @@ def forward(
         t_emb = t_emb.to(dtype=self.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
 
         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index f0eca75de169..c275e16744f4 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -592,7 +592,7 @@ def forward(
 
         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
 
         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
@@ -620,7 +620,7 @@ def forward(
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
 
         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 21e4db23a166..bd83024c9b7c 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -2059,7 +2059,7 @@ def forward(
             aug_emb = self.add_embedding(add_embeds)
 
         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
 
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -2068,7 +2068,10 @@ def forward(
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
 
         # 2. pre-process
diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index db4ace9656a3..059a6e807c8e 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -431,9 +431,11 @@ def forward(
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
         # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
 
         # 2. pre-process
         sample = self.conv_in(sample)
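The standalone sketch below is not part of the patch; it only illustrates the `output_size` keyword of `torch.repeat_interleave` that every hunk above switches to. Per the PyTorch documentation, supplying the precomputed output length lets the op skip the stream synchronization otherwise needed to infer the output shape along the repeated dimension.

# Standalone sketch, not part of the patch: the same pattern used in the hunks above.
import torch

x = torch.randn(2, 4)  # e.g. [batch, channels]
repeats = 3

# Baseline call: the output length along dim=0 is inferred by the op.
y_ref = x.repeat_interleave(repeats, dim=0)

# Patched pattern: the output length is supplied up front via `output_size`,
# avoiding the synchronization needed to compute the output shape.
y = x.repeat_interleave(repeats, dim=0, output_size=x.shape[0] * repeats)

assert torch.equal(y, y_ref)  # identical result, shape [6, 4]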