Skip to content

Commit 90830ed

Browse files
remove the unconditional forward pass
1 parent 65b3719 commit 90830ed

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
import math
2020
import os
21+
import random
2122
import shutil
2223
from contextlib import nullcontext
2324
from pathlib import Path
@@ -1094,6 +1095,14 @@ def load_model_hook(models, input_dir):
10941095
# TODO: Should a parameter be set here for passing? This is not present in Flux.
10951096
crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
10961097
crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
1098+
1099+
# this could be optimized by not having to do any text encoding and just
1100+
# doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
1101+
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
1102+
# Here, pooled_prompt_embeds (the 16 pad-token embeddings) is assigned directly to prompt_embeds
1103+
prompt_embeds = pooled_prompt_embeds
1104+
if args.offload:
1105+
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
10971106
# Predict.
10981107
noise_pred_cond = cogview4_transformer(
10991108
hidden_states=concatenated_noisy_model_input,
@@ -1104,17 +1113,6 @@ def load_model_hook(models, input_dir):
11041113
crop_coords=crops_coords_top_left,
11051114
return_dict=False,
11061115
)[0]
1107-
1108-
noise_pred_uncond = cogview4_transformer(
1109-
hidden_states=concatenated_noisy_model_input,
1110-
encoder_hidden_states=pooled_prompt_embeds,
1111-
timestep=timesteps,
1112-
original_size=original_size,
1113-
target_size=target_size,
1114-
crop_coords=crops_coords_top_left,
1115-
return_dict=False,
1116-
)[0]
1117-
model_pred = noise_pred_uncond + (noise_pred_cond - noise_pred_uncond)
11181116
# these weighting schemes use a uniform timestep sampling
11191117
# and instead post-weight the loss
11201118
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
@@ -1123,7 +1121,7 @@ def load_model_hook(models, input_dir):
11231121

11241122
weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
11251123
loss = torch.mean(
1126-
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
1124+
(weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
11271125
)
11281126
loss = loss.mean()
11291127
accelerator.backward(loss)

0 commit comments

Comments
 (0)