From 4b3b8a768cd10b5936f6d2a9dc9bbbc3a59ee1b0 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 3 Mar 2023 14:12:48 +0100
Subject: [PATCH 1/7] [Model offload] Add nice warning

---
 src/diffusers/pipelines/pipeline_utils.py | 19 ++++++++++++
 tests/test_pipelines.py                   | 36 +++++++++++++++++++++++
 2 files changed, 55 insertions(+)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 770fcba15124..a5822ca1a18f 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -49,6 +49,7 @@
     get_class_from_dynamic_module,
     http_user_agent,
     is_accelerate_available,
+    is_accelerate_version,
     is_safetensors_available,
     is_torch_version,
     is_transformers_available,
@@ -66,6 +67,10 @@
 from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
 
 
+if is_accelerate_available():
+    import accelerate
+
+
 INDEX_FILE = "diffusion_pytorch_model.bin"
 CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
 DUMMY_MODULES_FOLDER = "diffusers.utils"
@@ -335,6 +340,20 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
         if torch_device is None:
             return self
 
+        # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
+        def module_is_offloaded(module):
+            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
+                return False
+
+            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
+        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
+
+        if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
+            logger.warn(
+                f"It seems like you have activated model offloading by calling `enable_model_offload` or `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')`."
+            )
+
         module_names, _, _ = self.extract_init_dict(dict(self.config))
         for name in module_names.keys():
             module = getattr(self, name)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index e909e4561347..c1641b4f35f6 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -584,6 +584,42 @@ def test_stable_diffusion_components(self):
         assert image_img2img.shape == (1, 32, 32, 3)
         assert image_text2img.shape == (1, 64, 64, 3)
 
+    @require_torch_gpu
+    def test_pipe_false_offload_warn(self):
+        unet = self.dummy_cond_unet()
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        sd.enable_model_cpu_offload()
+
+        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+        with CaptureLogger(logger) as cap_logger:
+            sd.to("cuda")
+
+        assert "It is strongly recommended against doing so" in str(cap_logger)
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
     def test_set_scheduler(self):
         unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
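For context, the scenario this first patch warns about looks roughly like the sketch below. It is illustrative only and not part of the patch; the checkpoint name is just an example.

```python
# Illustrative sketch (not part of the patch): the call sequence that now logs
# the new warning. Any Stable Diffusion checkpoint would do; this one is only
# an example.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Model offloading installs accelerate `CpuOffload` hooks on each component,
# which move the models to the GPU lazily, one at a time, when they are used.
pipe.enable_model_cpu_offload()

# Manually moving the whole pipeline to the GPU afterwards throws those memory
# savings away, so `to("cuda")` now emits the warning added above.
pipe.to("cuda")
```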
From 7e2509e4575e445052eb548dd8dbfbaaeee1810b Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 3 Mar 2023 16:02:44 +0100
Subject: [PATCH 2/7] Treat sequential and model offload differently.

Sequential raises an error because the operation would fail with a cryptic
warning later.
---
 src/diffusers/pipelines/pipeline_utils.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index a5822ca1a18f..103a0ad5835c 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -341,17 +341,30 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
+        def module_is_sequentially_offloaded(module):
+            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+                return False
+
+            return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
                 return False
 
             return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
 
+        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
+        pipeline_is_sequentially_offloaded = any(module_is_sequentially_offloaded(module) for _, module in self.components.items())
+        if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
+            raise ValueError(
+                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+            )
+
+        # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
-
         if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
-            logger.warn(
-                f"It seems like you have activated model offloading by calling `enable_model_offload` or `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')`."
+            logger.warning(
+                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )
 
         module_names, _, _ = self.extract_init_dict(dict(self.config))
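With this change the sequential case fails fast instead of surfacing a confusing failure from inside accelerate later on. A rough sketch of what a caller now sees (illustrative; the checkpoint is an example):

```python
# Illustrative sketch: sequential offloading is incompatible with a manual
# move to GPU, so the pipeline now raises a descriptive ValueError up front.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # example checkpoint
)
pipe.enable_sequential_cpu_offload()

try:
    pipe.to("cuda")  # not compatible with sequential offloading
except ValueError as err:
    print(err)  # "It seems like you have activated sequential model offloading ..."
```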
From c0d6ee475a2ca7cb5f5a9af402b3f263efe9fdfb Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Fri, 3 Mar 2023 16:19:14 +0100
Subject: [PATCH 3/7] Forcibly move to cpu when offloading.

---
 src/diffusers/pipelines/pipeline_utils.py             |  4 ++--
 .../stable_diffusion/pipeline_stable_diffusion.py     | 10 +++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 103a0ad5835c..0ef9ff8d0708 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -336,7 +336,7 @@ def is_saveable_module(name, value):
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
         if torch_device is None:
             return self
 
@@ -371,7 +371,7 @@ def module_is_offloaded(module):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
-                if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
+                if module.dtype == torch.float16 and str(torch_device) in ["cpu"] and not silence_dtype_warnings:
                     logger.warning(
                         "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
                         " is not recommended to move them to `cpu` as running them will fail. Please make"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 48d13511cd41..2167631434a2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -217,11 +217,15 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
@@ -237,6 +241,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
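The forced CPU move covers the opposite direction of the same mistake: enabling offloading on a pipeline that already lives on the GPU. A hedged sketch of the flow this makes safe (the checkpoint name is an example, not from the patch):

```python
# Illustrative sketch: enabling offload on a pipeline that was already moved
# to the GPU. With this patch the pipeline is first moved back to CPU.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # example checkpoint
)
pipe.to("cuda")  # experiment on the GPU first ...

# ... then switch to offloading to save memory. enable_model_cpu_offload() now
# calls self.to("cpu", silence_dtype_warnings=True) so the fp16-on-CPU warning
# is not emitted for this intentional, internal move, and then empties the CUDA
# cache so the freed memory actually becomes visible.
pipe.enable_model_cpu_offload()
```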
From 5ef393ed7385222a00fd0ea58c817ae98b8d70ba Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 3 Mar 2023 16:23:40 +0100
Subject: [PATCH 4/7] make style

---
 src/diffusers/pipelines/pipeline_utils.py                   | 6 ++++--
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 6 +++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 0ef9ff8d0708..6fa4c9cec068 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -354,12 +354,14 @@ def module_is_offloaded(module):
             return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
 
         # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
-        pipeline_is_sequentially_offloaded = any(module_is_sequentially_offloaded(module) for _, module in self.components.items())
+        pipeline_is_sequentially_offloaded = any(
+            module_is_sequentially_offloaded(module) for _, module in self.components.items()
+        )
         if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
             raise ValueError(
                 "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
             )
-
+
         # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
         if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 2167631434a2..504479798617 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -219,13 +219,13 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
         if self.safety_checker is not None:
-            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
@@ -243,7 +243,7 @@ def enable_model_cpu_offload(self, gpu_id=0):
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
From 5ebd0241a031951aa4763072ebc2fac25986608e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 3 Mar 2023 16:31:08 +0100
Subject: [PATCH 5/7] one more fix

---
 src/diffusers/pipelines/pipeline_utils.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2235f1f4c486..65b348d2e7d3 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -376,10 +376,16 @@ def module_is_offloaded(module):
             )
 
         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
-                if module.dtype == torch.float16 and str(torch_device) in ["cpu"] and not silence_dtype_warnings:
+                if (
+                    module.dtype == torch.float16
+                    and str(torch_device) in ["cpu"]
+                    and not silence_dtype_warnings
+                    and not is_offloaded
+                ):
                     logger.warning(
                         "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
                         " is not recommended to move them to `cpu` as running them will fail. Please make"
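Patch 5 widens the dtype guard: once a pipeline is offloaded (either flavor), its weights legitimately sit on the CPU between forward passes, so an fp16 pipeline should not warn when it is moved to `cpu`. Illustrative sketch (example checkpoint):

```python
# Illustrative sketch: after this fix, moving an offloaded fp16 pipeline back
# to the CPU no longer triggers the "cannot run with `cpu` device" warning,
# because `is_offloaded` is True for it.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # example checkpoint
)
pipe.enable_model_cpu_offload()
pipe.to("cpu")  # no fp16/cpu warning: the pipeline is expected to idle on CPU
```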
From b9ed57c2369859e48b869faa62803f214ec02cb0 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 3 Mar 2023 16:34:24 +0100
Subject: [PATCH 6/7] make fix-copies

---
 .../pipelines/alt_diffusion/pipeline_alt_diffusion.py        | 8 ++++++++
 .../stable_diffusion/pipeline_cycle_diffusion.py             | 8 ++++++++
 .../pipeline_stable_diffusion_attend_and_excite.py           | 4 ++++
 .../stable_diffusion/pipeline_stable_diffusion_img2img.py    | 8 ++++++++
 .../stable_diffusion/pipeline_stable_diffusion_inpaint.py    | 8 ++++++++
 .../pipeline_stable_diffusion_inpaint_legacy.py              | 8 ++++++++
 .../pipeline_stable_diffusion_instruct_pix2pix.py            | 8 ++++++++
 .../pipeline_stable_diffusion_k_diffusion.py                 | 8 ++++++++
 .../pipeline_stable_diffusion_panorama.py                    | 4 ++++
 .../pipeline_stable_diffusion_pix2pix_zero.py                | 4 ++++
 .../stable_diffusion/pipeline_stable_diffusion_sag.py        | 4 ++++
 11 files changed, 72 insertions(+)

diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index 2ae3baa74db6..71e98480ed2d 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -214,6 +214,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -234,6 +238,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 385c3e09bb1e..e977071b9c6c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -237,6 +237,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -258,6 +262,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index 2e92c7c31c83..e5550f1d0d45 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -263,6 +263,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index ebc252430d23..172ab15a757e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -225,6 +225,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -246,6 +250,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index b9b207b6064e..b645ba667f77 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -272,6 +272,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -293,6 +297,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index a89d6814caef..1e84efa4163c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -216,6 +216,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -237,6 +241,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 98204734714e..9efa91e161b9 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -405,6 +405,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -426,6 +430,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
index f668d152a7f0..f3db54caa342 100755
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -137,6 +137,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -158,6 +162,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index 5f0ca6f67e66..b1f29fbef12b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -158,6 +158,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index 101acf0ad932..c78d327f69a9 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -372,6 +372,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 4ba3345f9442..0a34499e090c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -176,6 +176,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
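`make fix-copies` propagates the guarded CPU move to every pipeline whose offload methods are kept in sync with the Stable Diffusion pipeline, so the same pattern now applies there too; for instance (illustrative sketch, example checkpoint):

```python
# Illustrative sketch: the copied pipelines get the same protection, e.g.
# enabling sequential offload on an img2img pipeline that is already on GPU
# first moves it back to CPU before the accelerate hooks are installed.
import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # example checkpoint
)
pipe.to("cuda")
pipe.enable_sequential_cpu_offload()  # moves components back to CPU, then installs hooks
```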
From 78c2413b72cc23d0e1ed674880182fb3703afa58 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 3 Mar 2023 16:38:19 +0100
Subject: [PATCH 7/7] up

---
 .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 1b30ceddca65..1e7872e3b081 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -220,6 +220,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -240,6 +244,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
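The series only adds a test for the warning path; a companion test for the sequential error path could look like the sketch below. It is hypothetical, mirrors the fixtures of `test_pipe_false_offload_warn` from PATCH 1/7, and is not part of these patches.

```python
# Hypothetical companion test (not in this series): the sequential-offload
# path should raise rather than warn when the pipeline is moved to GPU.
@require_torch_gpu
def test_pipe_to_raises_on_sequential_offload(self):
    unet = self.dummy_cond_unet()
    scheduler = PNDMScheduler(skip_prk_steps=True)
    vae = self.dummy_vae
    bert = self.dummy_text_encoder
    tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    sd = StableDiffusionPipeline(
        unet=unet,
        scheduler=scheduler,
        vae=vae,
        text_encoder=bert,
        tokenizer=tokenizer,
        safety_checker=None,
        feature_extractor=self.dummy_extractor,
    )

    sd.enable_sequential_cpu_offload()

    with self.assertRaises(ValueError):
        sd.to("cuda")
```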