Skip to content

Commit 3d3ad59

Browse files
committed
vae fix
1 parent 199e741 commit 3d3ad59

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,9 @@ def __init__(
981981
# timestep embedding
982982
self.time_embedder = None
983983
self.scale_shift_table = None
984+
self.timestep_scale_multiplier = None
984985
if timestep_conditioning:
986+
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
985987
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
986988
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
987989

@@ -990,6 +992,9 @@ def __init__(
990992
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
991993
hidden_states = self.conv_in(hidden_states)
992994

995+
if self.timestep_scale_multiplier is not None:
996+
temb = temb * self.timestep_scale_multiplier
997+
993998
if torch.is_grad_enabled() and self.gradient_checkpointing:
994999
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
9951000

0 commit comments

Comments
 (0)