@@ -182,6 +182,7 @@ def _get_glm_embeds(
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -193,7 +194,7 @@ def _get_glm_embeds(
 
         text_inputs = self.tokenizer(
             prompt,
-            padding="longest",  # not use max length
+            padding=padding_type,
             max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
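For context, the two padding modes this new argument toggles behave as follows for any Hugging Face tokenizer. This is an illustrative sketch only; the tokenizer id below is a stand-in, not the GLM tokenizer the pipeline actually loads:

```python
# Illustrative sketch of tokenizer padding modes; "bert-base-uncased" is a
# stand-in tokenizer, not the GLM tokenizer used by the CogView4 pipeline.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
prompts = ["a cat", "a cat sitting on a mat"]

# padding="longest" (the current default): pad only to the longest prompt in the batch.
longest = tok(prompts, padding="longest", max_length=64, truncation=True, return_tensors="pt")

# padding="max_length": always pad to max_length, producing fixed-shape inputs.
fixed = tok(prompts, padding="max_length", max_length=64, truncation=True, return_tensors="pt")

print(longest["input_ids"].shape)  # torch.Size([2, <longest prompt length>])
print(fixed["input_ids"].shape)    # torch.Size([2, 64])
```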
@@ -239,6 +240,7 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -275,9 +277,8 @@ def encode_prompt(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
-
         if prompt_embeds is None:
-            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
+            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, padding_type, device, dtype)
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
@@ -296,7 +297,7 @@ def encode_prompt(
             )
 
             negative_prompt_embeds = self._get_glm_embeds(
-                negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
+                negative_prompt, num_images_per_prompt, max_sequence_length, "longest", device, dtype
             )
 
         return prompt_embeds, negative_prompt_embeds
@@ -450,6 +451,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 1024,
+        padding_type: str = "longest",  # Downstream tasks can change this to "max_length" padding.
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -579,7 +581,8 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
-            device=device,
+            padding_type=padding_type,
+            device=device,
         )
 
         # Prepare latents
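Taken together, the change threads a new `padding_type` argument from `__call__` through `encode_prompt` into `_get_glm_embeds`, defaulting to `"longest"` so existing behavior is unchanged, while negative prompt embeddings keep the hardcoded `"longest"` padding. A minimal usage sketch follows; it assumes the `CogView4Pipeline` class and the `THUDM/CogView4-6B` checkpoint id, neither of which appears in this diff:

```python
# Hypothetical usage sketch; the pipeline class and checkpoint id are assumptions
# not taken from this diff.
import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Default: prompt embeddings are padded to the longest prompt in the batch (unchanged behavior).
image = pipe(prompt="a red panda eating bamboo").images[0]

# Downstream tasks that need fixed-length prompt embeddings can pad to max_sequence_length instead.
image = pipe(
    prompt="a red panda eating bamboo",
    padding_type="max_length",
    max_sequence_length=1024,
).images[0]
```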