Commit 71f9235

add max length
1 parent 90830ed commit 71f9235

2 files changed, +20 -10 lines changed


examples/cogview4-control/train_control_cogview4.py

Lines changed: 12 additions & 5 deletions
@@ -132,6 +132,8 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step,
     control_image=validation_image,
     num_inference_steps=50,
     guidance_scale=args.guidance_scale,
+    max_sequence_length=max_sequence_length,  # For downstream-task training: with a fixed length, prompts can be encoded on a batch basis.
+    padding_type="max_length",
     generator=generator,
     height=args.resolution,
     width=args.resolution,
@@ -267,6 +269,9 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--max_sequence_length", type=int, default=128, help="The maximum sequence length for the prompt."
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
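Note (editorial): a minimal, standalone sketch of how the new flag parses (the value 224 is arbitrary, not part of the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_sequence_length", type=int, default=128, help="The maximum sequence length for the prompt."
)

args = parser.parse_args(["--max_sequence_length", "224"])
print(args.max_sequence_length)  # 224; if the flag is omitted, the default is 128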
@@ -1079,10 +1084,12 @@ def load_model_hook(models, input_dir):
         text_encoding_pipeline = text_encoding_pipeline.to("cuda")

         with torch.no_grad():
-            (
-                prompt_embeds,
-                pooled_prompt_embeds,
-            ) = text_encoding_pipeline.encode_prompt(captions, "")
+            # Since the batch will be padded, max_length should be used for padding.
+            prompt_embeds, pooled_prompt_embeds = text_encoding_pipeline.encode_prompt(
+                captions, "",
+                max_sequence_length=args.max_sequence_length,
+                padding_type="max_length",
+            )
         original_size = (args.resolution, args.resolution)
         original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)

@@ -1099,7 +1106,7 @@ def load_model_hook(models, input_dir):
         # this could be optimized by not having to do any text encoding and just
         # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
         if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
-            # 这里,直接将 pooled_prompt_embeds 16个 pad token 提供给 prompt_embeds
+            # Here, we directly pass the 16 pad tokens from pooled_prompt_embeds to prompt_embeds.
             prompt_embeds = pooled_prompt_embeds
         if args.offload:
             text_encoding_pipeline = text_encoding_pipeline.to("cpu")

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 8 additions & 5 deletions
@@ -182,6 +182,7 @@ def _get_glm_embeds(
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -193,7 +194,7 @@ def _get_glm_embeds(

         text_inputs = self.tokenizer(
             prompt,
-            padding="longest",  # not use max length
+            padding=padding_type,
             max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
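Note (editorial): padding_type is forwarded straight to the Hugging Face tokenizer's padding argument, so the accepted values are the tokenizer's own ("longest", "max_length", ...). A quick illustration of the difference, using bert-base-uncased purely as a small stand-in tokenizer (the pipeline itself uses the GLM tokenizer shipped with the CogView4 checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer for illustration
prompts = ["a cat", "a very detailed photograph of a cat sitting on a windowsill at golden hour"]

longest = tok(prompts, padding="longest", max_length=128, truncation=True, return_tensors="pt")
fixed = tok(prompts, padding="max_length", max_length=128, truncation=True, return_tensors="pt")

print(longest["input_ids"].shape)  # (2, N) where N is the longest prompt in this batch
print(fixed["input_ids"].shape)    # (2, 128) regardless of the batch contents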
@@ -239,6 +240,7 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -275,9 +277,8 @@ def encode_prompt(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
-
         if prompt_embeds is None:
-            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
+            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, padding_type, device, dtype)

         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
@@ -296,7 +297,7 @@ def encode_prompt(
             )

             negative_prompt_embeds = self._get_glm_embeds(
-                negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
+                negative_prompt, num_images_per_prompt, max_sequence_length, "longest", device, dtype
             )

         return prompt_embeds, negative_prompt_embeds
@@ -450,6 +451,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",  # For downstream tasks, this can be switched to "max_length".
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -579,7 +581,8 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
-            device=device,
+            padding_type=padding_type,
+            device=device,
         )

         # Prepare latents
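Note (editorial): an end-to-end sketch of calling the pipeline with the new arguments, assuming this patch is installed. The checkpoint id THUDM/CogView4-6B, the control-image path, and the sampling settings are illustrative, not taken from the commit. Also note that negative prompts are still encoded with padding="longest" in this commit (see the negative_prompt_embeds hunk above).

import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained(
    "THUDM/CogView4-6B", torch_dtype=torch.bfloat16  # illustrative checkpoint id
).to("cuda")
control_image = load_image("control.png")  # illustrative path

image = pipe(
    prompt="a red sports car on a mountain road at dusk",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    max_sequence_length=128,      # fixed prompt length, matching the training-side --max_sequence_length
    padding_type="max_length",    # new argument added by this commit; the default remains "longest"
).images[0]
image.save("output.png")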
