From 807c8ac9f98b9c65102e01905d49c11a066f1f83 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 2 Apr 2025 10:21:32 +0100 Subject: [PATCH 1/2] Fix enable_sequential_cpu_offload in CogView4Pipeline --- src/diffusers/pipelines/cogview4/pipeline_cogview4.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index c27a1a19774d..8550fa94f9e4 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -213,9 +213,7 @@ def _get_glm_embeds( device=text_input_ids.device, ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.device), output_hidden_states=True - ).hidden_states[-2] + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds From 4f38ee22f47105b4e0ec6236c747e2645be32cba Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 2 Apr 2025 10:45:39 +0100 Subject: [PATCH 2/2] make fix-copies --- src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index 92b138b7af95..7613bc3d0f40 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -216,9 +216,7 @@ def _get_glm_embeds( device=text_input_ids.device, ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.device), output_hidden_states=True - ).hidden_states[-2] + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds