Skip to content

Commit 199e741

Browse files
committed
update
1 parent d0bdf4b commit 199e741

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,17 +223,23 @@ def __init__(
223223

224224
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
225225
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
226-
hidden_states = (
226+
227+
residual = (
227228
hidden_states.unflatten(4, (-1, self.stride[2]))
228229
.unflatten(3, (-1, self.stride[1]))
229230
.unflatten(2, (-1, self.stride[0]))
230231
)
231-
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
232+
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
233+
residual = residual.unflatten(1, (-1, self.group_size))
234+
residual = residual.mean(dim=2)
232235

233-
residual = hidden_states
234-
hidden_states = hidden_states.unflatten(1, (-1, self.group_size))
235-
hidden_states = hidden_states.mean(dim=2)
236236
hidden_states = self.conv(hidden_states)
237+
hidden_states = (
238+
hidden_states.unflatten(4, (-1, self.stride[2]))
239+
.unflatten(3, (-1, self.stride[1]))
240+
.unflatten(2, (-1, self.stride[0]))
241+
)
242+
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
237243
hidden_states = hidden_states + residual
238244

239245
return hidden_states

0 commit comments

Comments
 (0)