
Commit 8b4f8ba

Use output_size in repeat_interleave (#11030)
1 parent 5428046 commit 8b4f8ba

12 files changed: +56 -27 lines
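Every hunk below applies the same mechanical change: each `repeat_interleave` call gains an explicit `output_size` equal to the size of the repeated dimension times the repeat count. A minimal sketch of the equivalence (an illustration, not code from the commit; per the PyTorch docs, `output_size` lets the op skip the stream synchronization otherwise needed to compute the output shape, and it keeps the result shape explicit for tracing and compilation):

import torch

# Hypothetical sizes; `emb` and `num_frames` only mirror the naming used in
# the diffs below.
emb = torch.randn(2, 8)  # [batch, channels]
num_frames = 16

a = emb.repeat_interleave(num_frames, dim=0)
b = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

assert a.shape == (32, 8)
assert torch.equal(a, b)  # identical values; output_size only pre-declares the length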

src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 4 deletions
@@ -741,10 +741,14 @@ def prepare_attention_mask(
 
         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )
 
         return attention_mask
 
@@ -3704,8 +3708,10 @@ def __call__(
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )
 
         if attn.norm_q is not None:
             query = attn.norm_q(query)
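For context, a toy version of the GQA/MQA branch above (sizes are invented): key and value carry fewer heads than query, so each KV head is repeated to serve its group of query heads.

import torch

batch, seq_len, head_dim = 2, 77, 64
heads, kv_heads = 8, 2  # four query heads share each KV head
heads_per_kv_head = heads // kv_heads

key = torch.randn(batch, kv_heads, seq_len, head_dim)
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=kv_heads * heads_per_kv_head)
assert key.shape == (batch, heads, seq_len, head_dim)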

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 4 additions & 2 deletions
@@ -190,7 +190,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         x = F.pixel_shuffle(x, self.factor)
 
         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
@@ -361,7 +361,9 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
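A toy version of the shortcut branch above (made-up sizes): repeating channels before `F.pixel_shuffle` gives the identity branch the `factor**2` channels per output channel that the shuffle consumes, so it matches the upsampled main branch.

import torch
import torch.nn.functional as F

factor = 2
repeats = factor**2  # pixel_shuffle consumes factor^2 channels per output channel
hidden_states = torch.randn(1, 3, 8, 8)  # [B, C, H, W]
y = hidden_states.repeat_interleave(repeats, dim=1, output_size=hidden_states.shape[1] * repeats)
y = F.pixel_shuffle(y, factor)
assert y.shape == (1, 3, 16, 16)  # spatially upsampled identity branch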

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
         if self.down_sample:
             identity = hidden_states[:, :, ::2]
         elif self.up_sample:
-            identity = hidden_states.repeat_interleave(2, dim=2)
+            identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
         else:
             identity = hidden_states
 

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 3 additions & 1 deletion
@@ -426,7 +426,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         w = w.repeat(num_channels)[None, :, None, None, None]  # [1, num_channels * num_freqs, 1, 1, 1]
 
         # Interleaved repeat of input channels to match w
-        h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
+        h = inputs.repeat_interleave(
+            num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
+        )  # [B, C * num_freqs, T, H, W]
         # Scale channels by frequency.
         h = w * h
 
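A sketch of the Fourier-features step above (hypothetical sizes and frequencies): each input channel is repeated once per frequency, so the interleaved copies line up with the per-frequency weights `w` tiled across channels.

import torch

B, C, num_freqs = 1, 2, 3
inputs = torch.randn(B, C, 4)
freqs = torch.logspace(0, num_freqs - 1, num_freqs, base=2.0)  # invented frequencies
w = freqs.repeat(C)[None, :, None]  # [1, C * num_freqs, 1]
h = inputs.repeat_interleave(num_freqs, dim=1, output_size=C * num_freqs)  # [B, C * num_freqs, 4]
h = w * h  # channel c's copies are scaled by f0..f{k-1}
assert h.shape == (B, C * num_freqs, 4)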

src/diffusers/models/controlnets/controlnet_sparsectrl.py

Lines changed: 1 addition & 1 deletion
@@ -687,7 +687,7 @@ def forward(
         t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 3 deletions
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
 
     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
-    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+    )  # [T, H*W, D // 4 * 3]
 
     pos_embed_temporal = pos_embed_temporal[:, None, :]
     pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -1154,8 +1156,8 @@ def get_1d_rotary_pos_embed(
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
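A toy view of the rotary branch above (shapes are invented): interleaved repetition duplicates each frequency so the cos/sin values line up with (even, odd) channel pairs.

import torch

S, D = 4, 8
pos = torch.arange(S, dtype=torch.float32)
freqs = torch.outer(pos, torch.ones(D // 2))  # [S, D/2], stand-in frequencies
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=D)  # [S, D]
# adjacent columns are duplicates: [f0, f0, f1, f1, ...]
assert torch.equal(freqs_cos[:, 0::2], freqs_cos[:, 1::2])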

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 12 additions & 6 deletions
@@ -227,13 +227,17 @@ def forward(
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
 
         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])
 
         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ def forward(
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
 
-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation

src/diffusers/models/transformers/prior_transformer.py

Lines changed: 5 additions & 1 deletion
@@ -353,7 +353,11 @@ def forward(
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )
 
         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
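A toy version of the mask broadcast above (made-up sizes): a [batch, query, key] mask is repeated per attention head so it matches attention scores laid out as [batch * heads, query, key].

import torch

batch, heads, q_len, k_len = 2, 4, 5, 5
mask = torch.zeros(batch, q_len, k_len)
mask = mask.repeat_interleave(heads, dim=0, output_size=batch * heads)
assert mask.shape == (batch * heads, q_len, k_len)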

src/diffusers/models/unets/unet_3d_condition.py

Lines changed: 4 additions & 2 deletions
@@ -638,8 +638,10 @@ def forward(
         t_emb = t_emb.to(dtype=self.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
 
         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])

src/diffusers/models/unets/unet_i2vgen_xl.py

Lines changed: 2 additions & 2 deletions
@@ -592,7 +592,7 @@ def forward(
 
         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
 
         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
@@ -620,7 +620,7 @@ def forward(
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
 
         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 5 additions & 2 deletions
@@ -2059,7 +2059,7 @@ def forward(
             aug_emb = self.add_embedding(add_embeds)
 
         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
 
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -2068,7 +2068,10 @@ def forward(
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
            image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
 
         # 2. pre-process

src/diffusers/models/unets/unet_spatio_temporal_condition.py

Lines changed: 4 additions & 2 deletions
@@ -431,9 +431,11 @@ def forward(
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
         # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
 
         # 2. pre-process
         sample = self.conv_in(sample)
