Commit 19d7d27

change convert script to use GlmModel instead of GlmForCausalLM
1 parent 71f9235 commit 19d7d27

4 files changed (+5, -5)

examples/cogview4-control/train_control_cogview4.py

Lines changed: 1 addition & 1 deletion

@@ -660,7 +660,7 @@ def prepare_train_dataset(dataset, accelerator):
         [
             transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
             transforms.ToTensor(),
-            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            transforms.Lambda(lambda x: x * 2 - 1)
         ]
     )
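For tensors coming out of transforms.ToTensor(), which are scaled to [0, 1], the two transforms are numerically identical: Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 = x * 2 - 1, mapping values into [-1, 1]. A minimal standalone sketch (not part of the commit) to confirm the equivalence:

import torch
from torchvision import transforms

# ToTensor() produces float tensors in [0, 1]; both transforms rescale to [-1, 1].
x = torch.rand(3, 64, 64)

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
rescale = transforms.Lambda(lambda t: t * 2 - 1)

# (x - 0.5) / 0.5 equals x * 2 - 1, so the results agree up to float rounding.
assert torch.allclose(normalize(x), rescale(x), atol=1e-6)

The Lambda form avoids the per-channel mean/std broadcast and reads as an explicit rescale; the pixel values the model sees are unchanged.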

scripts/convert_cogview4_to_diffusers_megatron.py

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@
 
 import torch
 from tqdm import tqdm
-from transformers import GlmForCausalLM, PreTrainedTokenizerFast
+from transformers import GlmModel, PreTrainedTokenizerFast
 
 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
@@ -326,7 +326,7 @@ def main(args):
     # Load the text encoder and tokenizer
     text_encoder_id = "THUDM/glm-4-9b-hf"
     tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
-    text_encoder = GlmForCausalLM.from_pretrained(
+    text_encoder = GlmModel.from_pretrained(
         text_encoder_id,
         cache_dir=args.text_encoder_cache_dir,
         torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
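In transformers, GlmForCausalLM follows the standard causal-LM layout: it wraps the GlmModel backbone (exposed as .model) and adds an lm_head projection on top. Since the conversion script only needs hidden states, loading GlmModel directly avoids instantiating the unused LM head. A hedged sketch of that relationship, assuming the standard transformers layout:

import torch
from transformers import GlmForCausalLM, GlmModel

model_id = "THUDM/glm-4-9b-hf"

# The causal-LM class holds the backbone as `.model` plus an `lm_head`.
causal_lm = GlmForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
assert isinstance(causal_lm.model, GlmModel)

# Loading the backbone class directly yields the same weights minus the head.
text_encoder = GlmModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)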

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 1 addition & 1 deletion

@@ -215,7 +215,7 @@ def _get_glm_embeds(
         )
         text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
+            text_input_ids.to(self.text_encoder.device), output_hidden_states=True
         ).hidden_states[-2]
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
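This fix follows from the conversion change above: the text encoder is now a GlmModel, which has no inner .model attribute, so the old lookup would raise AttributeError. Any transformers PreTrainedModel exposes .device directly. A minimal sketch of the embedding call under that assumption (the prompt is illustrative):

import torch
from transformers import GlmModel, PreTrainedTokenizerFast

text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmModel.from_pretrained(text_encoder_id, torch_dtype=torch.bfloat16)

text_input_ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids

# `.device` works on the model itself; no `.model` indirection is needed.
prompt_embeds = text_encoder(
    text_input_ids.to(text_encoder.device), output_hidden_states=True
).hidden_states[-2]  # CogView4 consumes the second-to-last hidden state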

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 1 addition & 1 deletion

@@ -219,7 +219,7 @@ def _get_glm_embeds(
         )
         text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
+            text_input_ids.to(self.text_encoder.device), output_hidden_states=True
         ).hidden_states[-2]
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
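The control pipeline receives the identical one-line fix. For context, a hedged end-to-end usage sketch of the patched pipeline (the checkpoint id and prompt are illustrative):

import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Prompt encoding runs through _get_glm_embeds, which now resolves the
# target device via `self.text_encoder.device`.
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")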
