diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index c27a1a19774d..8550fa94f9e4 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -213,9 +213,7 @@ def _get_glm_embeds( device=text_input_ids.device, ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.device), output_hidden_states=True - ).hidden_states[-2] + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 92b138b7af95..7613bc3d0f40 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -216,9 +216,7 @@ def _get_glm_embeds( device=text_input_ids.device, ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.device), output_hidden_states=True - ).hidden_states[-2] + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds