File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
src/diffusers/models/autoencoders Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -981,7 +981,9 @@ def __init__(
981
981
# timestep embedding
982
982
self .time_embedder = None
983
983
self .scale_shift_table = None
984
+ self .timestep_scale_multiplier = None
984
985
if timestep_conditioning :
986
+ self .timestep_scale_multiplier = nn .Parameter (torch .tensor (1000.0 , dtype = torch .float32 ))
985
987
self .time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings (output_channel * 2 , 0 )
986
988
self .scale_shift_table = nn .Parameter (torch .randn (2 , output_channel ) / output_channel ** 0.5 )
987
989
@@ -990,6 +992,9 @@ def __init__(
990
992
def forward (self , hidden_states : torch .Tensor , temb : Optional [torch .Tensor ] = None ) -> torch .Tensor :
991
993
hidden_states = self .conv_in (hidden_states )
992
994
995
+ if self .timestep_scale_multiplier is not None :
996
+ temb = temb * self .timestep_scale_multiplier
997
+
993
998
if torch .is_grad_enabled () and self .gradient_checkpointing :
994
999
hidden_states = self ._gradient_checkpointing_func (self .mid_block , hidden_states , temb )
995
1000
You can’t perform that action at this time.
0 commit comments