18
18
import logging
19
19
import math
20
20
import os
21
+ import random
21
22
import shutil
22
23
from contextlib import nullcontext
23
24
from pathlib import Path
@@ -1094,6 +1095,14 @@ def load_model_hook(models, input_dir):
1094
1095
# TODO: Should a parameter be set here for passing? This is not present in Flux.
1095
1096
crops_coords_top_left = torch .tensor ([(0 , 0 )], dtype = prompt_embeds .dtype , device = prompt_embeds .device )
1096
1097
crops_coords_top_left = crops_coords_top_left .repeat (len (batch ["captions" ]), 1 )
1098
+
1099
+ # This could be optimized by skipping text encoding entirely and just
1100
+ # creating zero tensors of the required shapes for `prompt_embeds` and `pooled_prompt_embeds`.
1101
+ if args .proportion_empty_prompts and random .random () < args .proportion_empty_prompts :
1102
+ # Here, pass the 16 pad tokens in pooled_prompt_embeds directly to prompt_embeds
1103
+ prompt_embeds = pooled_prompt_embeds
1104
+ if args .offload :
1105
+ text_encoding_pipeline = text_encoding_pipeline .to ("cpu" )
1097
1106
# Predict.
1098
1107
noise_pred_cond = cogview4_transformer (
1099
1108
hidden_states = concatenated_noisy_model_input ,
@@ -1104,17 +1113,6 @@ def load_model_hook(models, input_dir):
1104
1113
crop_coords = crops_coords_top_left ,
1105
1114
return_dict = False ,
1106
1115
)[0 ]
1107
-
1108
- noise_pred_uncond = cogview4_transformer (
1109
- hidden_states = concatenated_noisy_model_input ,
1110
- encoder_hidden_states = pooled_prompt_embeds ,
1111
- timestep = timesteps ,
1112
- original_size = original_size ,
1113
- target_size = target_size ,
1114
- crop_coords = crops_coords_top_left ,
1115
- return_dict = False ,
1116
- )[0 ]
1117
- model_pred = noise_pred_uncond + (noise_pred_cond - noise_pred_uncond )
1118
1116
# these weighting schemes use a uniform timestep sampling
1119
1117
# and instead post-weight the loss
1120
1118
weighting = compute_loss_weighting_for_sd3 (weighting_scheme = args .weighting_scheme , sigmas = sigmas )
@@ -1123,7 +1121,7 @@ def load_model_hook(models, input_dir):
1123
1121
1124
1122
weighting = weighting .view (len (batch ["captions" ]), 1 , 1 , 1 )
1125
1123
loss = torch .mean (
1126
- (weighting .float () * (model_pred .float () - target .float ()) ** 2 ).reshape (target .shape [0 ], - 1 ), 1
1124
+ (weighting .float () * (noise_pred_cond .float () - target .float ()) ** 2 ).reshape (target .shape [0 ], - 1 ), 1
1127
1125
)
1128
1126
loss = loss .mean ()
1129
1127
accelerator .backward (loss )
0 commit comments