diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 66b56740ef13..381def950169 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -65,7 +65,7 @@
     numpy_to_pil,
 )
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import is_compiled_module
+from ..utils.torch_utils import get_device, is_compiled_module
 
 
 if is_torch_npu_available():
@@ -1084,19 +1084,20 @@ def remove_all_hooks(self):
             accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []
 
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+        `forward` method is called, and the model remains on the accelerator until the next model runs. Memory savings
+        are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative
+        execution of the `unet`.
 
         Arguments:
             gpu_id (`int`, *optional*):
                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+            device (`torch.Device` or `str`, *optional*, defaults to None):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
+                automatically detect the available accelerator and use it.
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
 
@@ -1118,6 +1119,11 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
 
         self.remove_all_hooks()
 
+        if device is None:
+            device = get_device()
+            if device == "cpu":
+                raise RuntimeError("`enable_model_cpu_offload` requires an accelerator, but none was found")
+
         torch_device = torch.device(device)
         device_index = torch_device.index
 
@@ -1196,20 +1202,20 @@ def maybe_free_model_hooks(self):
         # make sure the model is in the same state as before calling it
         self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
 
-    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
         r"""
         Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
         dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
-        and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
-        method called. Offloading happens on a submodule basis. Memory savings are higher than with
+        and then moved to `torch.device('meta')` and loaded to the accelerator only when their specific submodule has
+        its `forward` method called. Offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
 
         Arguments:
             gpu_id (`int`, *optional*):
                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+            device (`torch.Device` or `str`, *optional*, defaults to None):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
+                automatically detect the available accelerator and use it.
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
 
@@ -1225,6 +1231,11 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
             )
 
+        if device is None:
+            device = get_device()
+            if device == "cpu":
+                raise RuntimeError("`enable_sequential_cpu_offload` requires an accelerator, but none was found")
+
         torch_device = torch.device(device)
         device_index = torch_device.index
 
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 3c8911773e39..a5df07e4a3c2 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -159,3 +159,12 @@ def get_torch_cuda_device_capability():
         return float(compute_capability)
     else:
         return None
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+    else:
+        return "cpu"
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 00c7636ed9fd..caa7755904a5 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -1816,7 +1816,12 @@ def test_pipe_same_device_id_offload(self):
             feature_extractor=self.dummy_extractor,
         )
 
-        sd.enable_model_cpu_offload(gpu_id=5)
+        # `enable_model_cpu_offload` detects the device type when it is not passed
+        # and raises a RuntimeError if the detected device is `cpu`.
+        # This test only checks whether `_offload_gpu_id` is set correctly,
+        # so the device passed can be any supported `torch.device` type.
+        # This allows us to keep the test under `PipelineFastTests`.
+        sd.enable_model_cpu_offload(gpu_id=5, device="cuda")
         assert sd._offload_gpu_id == 5
         sd.maybe_free_model_hooks()
         assert sd._offload_gpu_id == 5
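
With this change, both offload helpers resolve the accelerator themselves when `device` is not passed. A minimal usage sketch under the assumption that a CUDA or XPU device is present (the checkpoint name is only an example); on a CPU-only machine the new `RuntimeError` is raised instead:

```python
import torch

from diffusers import DiffusionPipeline

# Example checkpoint; any diffusers pipeline behaves the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)

# No `device` argument: get_device() picks "cuda" if available, otherwise "xpu";
# a CPU-only environment raises RuntimeError.
pipe.enable_model_cpu_offload()

image = pipe("an astronaut riding a horse on the moon").images[0]
```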