diff --git a/docs/source/en/api/utilities.md b/docs/source/en/api/utilities.md index b0b78928fb4b..b653cdafbb28 100644 --- a/docs/source/en/api/utilities.md +++ b/docs/source/en/api/utilities.md @@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers. ## apply_layerwise_casting [[autodoc]] hooks.layerwise_casting.apply_layerwise_casting + +## apply_group_offloading + +[[autodoc]] hooks.group_offloading.apply_group_offloading diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 4cdc60401914..9467a770d484 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run +## Group offloading + +Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced. + +To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]: + +```python +import torch +from diffusers import CogVideoXPipeline +from diffusers.hooks import apply_group_offloading +from diffusers.utils import export_to_video + +# Load the pipeline +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + +# We can utilize the enable_group_offload method for Diffusers model implementations +pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) + +# For any other model implementations, the apply_group_offloading function can be used +apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2) +apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level") + +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline. +print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") +export_to_video(video, "output.mp4", fps=8) +``` + +Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams. + ## FP8 layerwise weight-casting PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting. diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index e745b1320e84..56be0bbdf305 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,6 +2,7 @@ if is_torch_available(): + from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py new file mode 100644 index 000000000000..c389c5dc9826 --- /dev/null +++ b/src/diffusers/hooks/group_offloading.py @@ -0,0 +1,678 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext +from typing import Dict, List, Optional, Set, Tuple + +import torch + +from ..utils import get_logger, is_accelerate_available +from .hooks import HookRegistry, ModelHook + + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload + from accelerate.utils import send_to_device + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# fmt: off +_GROUP_OFFLOADING = "group_offloading" +_LAYER_EXECUTION_TRACKER = "layer_execution_tracker" +_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" + +_SUPPORTED_PYTORCH_LAYERS = ( + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, + # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX + # because of double invocation of the same norm layer in CogVideoXLayerNorm +) +# fmt: on + + +class ModuleGroup: + def __init__( + self, + modules: List[torch.nn.Module], + offload_device: torch.device, + onload_device: torch.device, + offload_leader: torch.nn.Module, + onload_leader: Optional[torch.nn.Module] = None, + parameters: Optional[List[torch.nn.Parameter]] = None, + buffers: Optional[List[torch.Tensor]] = None, + non_blocking: bool = False, + stream: Optional[torch.cuda.Stream] = None, + cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, + onload_self: bool = True, + ) -> None: + self.modules = modules + self.offload_device = offload_device + self.onload_device = onload_device + self.offload_leader = offload_leader + self.onload_leader = onload_leader + self.parameters = parameters + self.buffers = buffers + self.non_blocking = non_blocking or stream is not None + self.stream = stream + self.cpu_param_dict = cpu_param_dict + self.onload_self = onload_self + + if self.stream is not None and self.cpu_param_dict is None: + raise ValueError("cpu_param_dict must be provided when using stream for data transfer.") + + def onload_(self): + r"""Onloads the group of modules to the onload_device.""" + context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + with context: + for group_module in self.modules: + group_module.to(self.onload_device, non_blocking=self.non_blocking) + if self.parameters is not None: + for param in self.parameters: + param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + + def offload_(self): + r"""Offloads the group of modules to the offload_device.""" + if self.stream is not None: + torch.cuda.current_stream().synchronize() + for group_module in self.modules: + for param in group_module.parameters(): + param.data = self.cpu_param_dict[param] + else: + for group_module in self.modules: + group_module.to(self.offload_device, non_blocking=self.non_blocking) + if self.parameters is not None: + for param in self.parameters: + param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking) + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) + + +class GroupOffloadingHook(ModelHook): + r""" + A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for + computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader" + module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module + group is responsible for onloading the current module group. + """ + + _is_stateful = False + + def __init__( + self, + group: ModuleGroup, + next_group: Optional[ModuleGroup] = None, + ) -> None: + self.group = group + self.next_group = next_group + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + if self.group.offload_leader == module: + self.group.offload_() + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward + # method is the onload_leader of the group. + if self.group.onload_leader is None: + self.group.onload_leader = module + + # If the current module is the onload_leader of the group, we onload the group if it is supposed + # to onload itself. In the case of using prefetching with streams, we onload the next group if + # it is not supposed to onload itself. + if self.group.onload_leader == module: + if self.group.onload_self: + self.group.onload_() + if self.next_group is not None and not self.next_group.onload_self: + self.next_group.onload_() + + args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + if self.group.offload_leader == module: + self.group.offload_() + return output + + +class LazyPrefetchGroupOffloadingHook(ModelHook): + r""" + A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module. + This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer + invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows + prefetching groups in the correct order. + """ + + _is_stateful = False + + def __init__(self): + self.execution_order: List[Tuple[str, torch.nn.Module]] = [] + self._layer_execution_tracker_module_names = set() + + def initialize_hook(self, module): + # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any + # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the + # layers are executed during the forward pass. + for name, submodule in module.named_modules(): + if name == "" or not hasattr(submodule, "_diffusers_hook"): + continue + + registry = HookRegistry.check_if_exists_or_initialize(submodule) + group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING) + + if group_offloading_hook is not None: + + def make_execution_order_update_callback(current_name, current_submodule): + def callback(): + logger.debug(f"Adding {current_name} to the execution order") + self.execution_order.append((current_name, current_submodule)) + + return callback + + layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule)) + registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER) + self._layer_execution_tracker_module_names.add(name) + + return module + + def post_forward(self, module, output): + # At this point, for the current modules' submodules, we know the execution order of the layers. We can now + # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each + # group offloading hook. + num_executed = len(self.execution_order) + execution_order_module_names = {name for name, _ in self.execution_order} + + # It may be possible that some layers were not executed during the forward pass. This can happen if the layer + # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we + # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors + # if the missing layers end up being executed in the future. + if execution_order_module_names != self._layer_execution_tracker_module_names: + unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names) + logger.warning( + "It seems like some layers were not executed during the forward pass. This may lead to problems when " + "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please " + "make sure that all layers are executed during the forward pass. The following layers were not executed:\n" + f"{unexecuted_layers=}" + ) + + # Remove the layer execution tracker hooks from the submodules + base_module_registry = module._diffusers_hook + registries = [submodule._diffusers_hook for _, submodule in self.execution_order] + + for i in range(num_executed): + registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False) + + # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass + base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False) + + # Apply lazy prefetching by setting required attributes + group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] + if num_executed > 0: + base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) + base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group + base_module_group_offloading_hook.next_group.onload_self = False + + for i in range(num_executed - 1): + name1, _ = self.execution_order[i] + name2, _ = self.execution_order[i + 1] + logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}") + group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group + group_offloading_hooks[i].next_group.onload_self = False + + return output + + +class LayerExecutionTrackerHook(ModelHook): + r""" + A hook that tracks the order in which the layers are executed during the forward pass by calling back to the + LazyPrefetchGroupOffloadingHook to update the execution order. + """ + + _is_stateful = False + + def __init__(self, execution_order_update_callback): + self.execution_order_update_callback = execution_order_update_callback + + def pre_forward(self, module, *args, **kwargs): + self.execution_order_update_callback() + return args, kwargs + + +def apply_group_offloading( + module: torch.nn.Module, + onload_device: torch.device, + offload_device: torch.device = torch.device("cpu"), + offload_type: str = "block_level", + num_blocks_per_group: Optional[int] = None, + non_blocking: bool = False, + use_stream: bool = False, +) -> None: + r""" + Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and + where it is beneficial, we need to first provide some context on how other supported offloading methods work. + + Typically, offloading is done at two levels: + - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It + works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device + when needed for computation. This method is more memory-efficient than keeping all components on the accelerator, + but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of + the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward + pass. + - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It + works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and + onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator + memory, but can be slower due to the excessive number of device synchronizations. + + Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers, + (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level + offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is + reduced. + + Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to + overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This + is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to + the accelerator device while the current layer is being executed - this increases the memory requirements slightly. + Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + offload_device (`torch.device`, defaults to `torch.device("cpu")`): + The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. + offload_type (`str`, defaults to "block_level"): + The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is + "block_level". + num_blocks_per_group (`int`, *optional*): + The number of blocks per group when using offload_type="block_level". This is required when using + offload_type="block_level". + non_blocking (`bool`, defaults to `False`): + If True, offloading and onloading is done with non-blocking data transfer. + use_stream (`bool`, defaults to `False`): + If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for + overlapping computation and data transfer. + + Example: + ```python + >>> from diffusers import CogVideoXTransformer3DModel + >>> from diffusers.hooks import apply_group_offloading + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> apply_group_offloading( + ... transformer, + ... onload_device=torch.device("cuda"), + ... offload_device=torch.device("cpu"), + ... offload_type="block_level", + ... num_blocks_per_group=2, + ... use_stream=True, + ... ) + ``` + """ + + stream = None + if use_stream: + if torch.cuda.is_available(): + stream = torch.cuda.Stream() + else: + raise ValueError("Using streams for data transfer requires a CUDA device.") + + _raise_error_if_accelerate_model_or_sequential_hook_present(module) + + if offload_type == "block_level": + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") + + _apply_group_offloading_block_level( + module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream + ) + elif offload_type == "leaf_level": + _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream) + else: + raise ValueError(f"Unsupported offload_type: {offload_type}") + + +def _apply_group_offloading_block_level( + module: torch.nn.Module, + num_blocks_per_group: int, + offload_device: torch.device, + onload_device: torch.device, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, +) -> None: + r""" + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to + the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + offload_device (`torch.device`): + The device to which the group of modules are offloaded. This should typically be the CPU. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + non_blocking (`bool`): + If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation + and data transfer. + stream (`torch.cuda.Stream`, *optional*): + If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful + for overlapping computation and data transfer. + """ + + # Create a pinned CPU parameter dict for async data transfer if streams are to be used + cpu_param_dict = None + if stream is not None: + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} + + # Create module groups for ModuleList and Sequential blocks + modules_with_group_offloading = set() + unmatched_modules = [] + matched_module_groups = [] + for name, submodule in module.named_children(): + if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + unmatched_modules.append((name, submodule)) + modules_with_group_offloading.add(name) + continue + + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + group = ModuleGroup( + modules=current_modules, + offload_device=offload_device, + onload_device=onload_device, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], + non_blocking=non_blocking, + stream=stream, + cpu_param_dict=cpu_param_dict, + onload_self=stream is None, + ) + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + + # Apply group offloading hooks to the module groups + for i, group in enumerate(matched_module_groups): + next_group = ( + matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None + ) + + for group_module in group.modules: + _apply_group_offloading_hook(group_module, group, next_group) + + # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately + # when the forward pass of this module is called. This is because the top-level module is not + # part of any group (as doing so would lead to no VRAM savings). + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + parameters = [param for _, param in parameters] + buffers = [buffer for _, buffer in buffers] + + # Create a group for the unmatched submodules of the top-level module so that they are on the correct + # device when the forward pass is called. + unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] + unmatched_group = ModuleGroup( + modules=unmatched_modules, + offload_device=offload_device, + onload_device=onload_device, + offload_leader=module, + onload_leader=module, + parameters=parameters, + buffers=buffers, + non_blocking=False, + stream=None, + cpu_param_dict=None, + onload_self=True, + ) + next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None + _apply_group_offloading_hook(module, unmatched_group, next_group) + + +def _apply_group_offloading_leaf_level( + module: torch.nn.Module, + offload_device: torch.device, + onload_device: torch.device, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, +) -> None: + r""" + This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory + requirements. However, it can be slower compared to other offloading methods due to the excessive number of device + synchronizations. When using devices that support streams to overlap data transfer and computation, this method can + reduce memory usage without any performance degradation. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + offload_device (`torch.device`): + The device to which the group of modules are offloaded. This should typically be the CPU. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + non_blocking (`bool`): + If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation + and data transfer. + stream (`torch.cuda.Stream`, *optional*): + If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful + for overlapping computation and data transfer. + """ + + # Create a pinned CPU parameter dict for async data transfer if streams are to be used + cpu_param_dict = None + if stream is not None: + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} + + # Create module groups for leaf modules and apply group offloading hooks + modules_with_group_offloading = set() + for name, submodule in module.named_modules(): + if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): + continue + group = ModuleGroup( + modules=[submodule], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=submodule, + onload_leader=submodule, + non_blocking=non_blocking, + stream=stream, + cpu_param_dict=cpu_param_dict, + onload_self=True, + ) + _apply_group_offloading_hook(submodule, group, None) + modules_with_group_offloading.add(name) + + # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass + # of the module is called + module_dict = dict(module.named_modules()) + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + + # Find closest module parent for each parameter and buffer, and attach group hooks + parent_to_parameters = {} + for name, param in parameters: + parent_name = _find_parent_module_in_module_dict(name, module_dict) + if parent_name in parent_to_parameters: + parent_to_parameters[parent_name].append(param) + else: + parent_to_parameters[parent_name] = [param] + + parent_to_buffers = {} + for name, buffer in buffers: + parent_name = _find_parent_module_in_module_dict(name, module_dict) + if parent_name in parent_to_buffers: + parent_to_buffers[parent_name].append(buffer) + else: + parent_to_buffers[parent_name] = [buffer] + + parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys()) + for name in parent_names: + parameters = parent_to_parameters.get(name, []) + buffers = parent_to_buffers.get(name, []) + parent_module = module_dict[name] + assert getattr(parent_module, "_diffusers_hook", None) is None + group = ModuleGroup( + modules=[], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=parent_module, + onload_leader=parent_module, + parameters=parameters, + buffers=buffers, + non_blocking=non_blocking, + stream=stream, + cpu_param_dict=cpu_param_dict, + onload_self=True, + ) + _apply_group_offloading_hook(parent_module, group, None) + + if stream is not None: + # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer + # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the + # execution order and apply prefetching in the correct order. + unmatched_group = ModuleGroup( + modules=[], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=module, + onload_leader=module, + parameters=None, + buffers=None, + non_blocking=False, + stream=None, + cpu_param_dict=None, + onload_self=True, + ) + _apply_lazy_group_offloading_hook(module, unmatched_group, None) + + +def _apply_group_offloading_hook( + module: torch.nn.Module, + group: ModuleGroup, + next_group: Optional[ModuleGroup] = None, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(module) + + # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent + # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. + if registry.get_hook(_GROUP_OFFLOADING) is None: + hook = GroupOffloadingHook(group, next_group) + registry.register_hook(hook, _GROUP_OFFLOADING) + + +def _apply_lazy_group_offloading_hook( + module: torch.nn.Module, + group: ModuleGroup, + next_group: Optional[ModuleGroup] = None, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(module) + + # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent + # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. + if registry.get_hook(_GROUP_OFFLOADING) is None: + hook = GroupOffloadingHook(group, next_group) + registry.register_hook(hook, _GROUP_OFFLOADING) + + lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() + registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) + + +def _gather_parameters_with_no_group_offloading_parent( + module: torch.nn.Module, modules_with_group_offloading: Set[str] +) -> List[torch.nn.Parameter]: + parameters = [] + for name, parameter in module.named_parameters(): + has_parent_with_group_offloading = False + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in modules_with_group_offloading: + has_parent_with_group_offloading = True + break + atoms.pop() + if not has_parent_with_group_offloading: + parameters.append((name, parameter)) + return parameters + + +def _gather_buffers_with_no_group_offloading_parent( + module: torch.nn.Module, modules_with_group_offloading: Set[str] +) -> List[torch.Tensor]: + buffers = [] + for name, buffer in module.named_buffers(): + has_parent_with_group_offloading = False + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in modules_with_group_offloading: + has_parent_with_group_offloading = True + break + atoms.pop() + if not has_parent_with_group_offloading: + buffers.append((name, buffer)) + return buffers + + +def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str: + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in module_dict: + return parent_name + atoms.pop() + return "" + + +def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None: + if not is_accelerate_available(): + return + for name, submodule in module.named_modules(): + if not hasattr(submodule, "_hf_hook"): + continue + if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)): + raise ValueError( + f"Cannot apply group offloading to a module that is already applying an alternative " + f"offloading strategy from Accelerate. If you want to apply group offloading, please " + f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})" + ) + + +def _is_group_offload_enabled(module: torch.nn.Module) -> bool: + for submodule in module.modules(): + if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: + return True + return False + + +def _get_group_onload_device(module: torch.nn.Module) -> torch.device: + for submodule in module.modules(): + if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: + return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device + raise ValueError("Group offloading is not enabled for the provided module.") diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py index e8e372a709d7..a8c2a2fd3840 100644 --- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py +++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py @@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = False + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index 4759b9141242..a0b3309dc522 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ``` """ + _supports_group_offloading = False + @register_to_config def __init__( self, diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index e754e134b35f..84215389bf6a 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin): """ _skip_layerwise_casting_patterns = ["quantize"] + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index eb3063ff0c30..67874f75c8f9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -34,7 +34,7 @@ from typing_extensions import Self from .. import __version__ -from ..hooks import apply_layerwise_casting +from ..hooks import apply_group_offloading, apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -87,7 +87,17 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + from ..hooks.group_offloading import _get_group_onload_device + + try: + # Try to get the onload device from the group offloading hook + return _get_group_onload_device(parameter) + except ValueError: + pass + try: + # If the onload device is not available due to no group offloading hooks, try to get the device + # from the first parameter or buffer parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) return next(parameters_and_buffers).device except StopIteration: @@ -166,6 +176,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _no_split_modules = None _keep_in_fp32_modules = None _skip_layerwise_casting_patterns = None + _supports_group_offloading = True def __init__(self): super().__init__() @@ -437,6 +448,55 @@ def enable_layerwise_casting( self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking ) + def enable_group_offload( + self, + onload_device: torch.device, + offload_device: torch.device = torch.device("cpu"), + offload_type: str = "block_level", + num_blocks_per_group: Optional[int] = None, + non_blocking: bool = False, + use_stream: bool = False, + ) -> None: + r""" + Activates group offloading for the current model. + + See [`~hooks.group_offloading.apply_group_offloading`] for more information. + + Example: + + ```python + >>> from diffusers import CogVideoXTransformer3DModel + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> transformer.enable_group_offload( + ... onload_device=torch.device("cuda"), + ... offload_device=torch.device("cpu"), + ... offload_type="leaf_level", + ... use_stream=True, + ... ) + ``` + """ + if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream: + msg = ( + "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first " + "forward pass is executed with tiling enabled. Please make sure to either:\n" + "1. Run a forward pass with small input shapes.\n" + "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)." + ) + logger.warning(msg) + if not self._supports_group_offloading: + raise ValueError( + f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute " + f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please " + f"open an issue at https://github.com/huggingface/diffusers/issues." + ) + apply_group_offloading( + self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream + ) + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -1170,6 +1230,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): + from ..hooks.group_offloading import _is_group_offload_enabled + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: if getattr(self, "is_loaded_in_8bit", False): @@ -1182,13 +1244,34 @@ def cuda(self, *args, **kwargs): "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." ) + + # Checks if group offloading is enabled + if _is_group_offload_enabled(self): + logger.warning( + f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported." + ) + return self + return super().cuda(*args, **kwargs) # Adapted from `transformers`. @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): + from ..hooks.group_offloading import _is_group_offload_enabled + + device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs dtype_present_in_args = "dtype" in kwargs + # Try converting arguments to torch.device in case they are passed as strings + for arg in args: + if not isinstance(arg, str): + continue + try: + torch.device(arg) + device_arg_or_kwarg_present = True + except RuntimeError: + pass + if not dtype_present_in_args: for arg in args: if isinstance(arg, torch.dtype): @@ -1213,6 +1296,13 @@ def to(self, *args, **kwargs): "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." ) + + if _is_group_offload_enabled(self) and device_arg_or_kwarg_present: + logger.warning( + f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported." + ) + return self + return super().to(*args, **kwargs) # Taken from `transformers`. diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 6e83f49db71c..cdc0738050e4 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 13aa7d076d03..5608a0f605a6 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): """ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"] + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2fde0bb9f861..2a84af64f8e2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -394,6 +394,7 @@ def to(self, *args, **kwargs): ) device = device or device_arg + device_type = torch.device(device).type if device is not None else None pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items()) # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. @@ -424,7 +425,7 @@ def module_is_offloaded(module): "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." ) - if device and torch.device(device).type == "cuda": + if device_type == "cuda": if pipeline_is_sequentially_offloaded and not pipeline_has_bnb: raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." @@ -437,7 +438,7 @@ def module_is_offloaded(module): # Display a warning in this case (the operation succeeds but the benefits are lost) pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": + if pipeline_is_offloaded and device_type == "cuda": logger.warning( f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." ) @@ -449,6 +450,7 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) + is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module) if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: logger.warning( @@ -460,11 +462,21 @@ def module_is_offloaded(module): f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) + # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling + # components can be from outside diffusers too, but still have group offloading enabled. + if ( + self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module) + and device is not None + ): + logger.warning( + f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported." + ) + # This can happen for `transformer` models. CPU placement was added in # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) - elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb: + elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded: module.to(device, dtype) if ( @@ -1023,6 +1035,19 @@ def _execution_device(self): [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from Accelerate's module hooks. """ + from ..hooks.group_offloading import _get_group_onload_device + + # When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential + # offloading. We need to return the onload device of the group offloading hooks so that the intermediates + # required for computation (latents, prompt embeddings, etc.) can be created on the correct device. + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + try: + return _get_group_onload_device(model) + except ValueError: + pass + for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: continue @@ -1061,6 +1086,8 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ + self._maybe_raise_error_if_group_offload_active(raise_error=True) + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( @@ -1172,6 +1199,8 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ + self._maybe_raise_error_if_group_offload_active(raise_error=True) + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: @@ -1896,6 +1925,24 @@ def from_pipe(cls, pipeline, **kwargs): return new_pipeline + def _maybe_raise_error_if_group_offload_active( + self, raise_error: bool = False, module: Optional[torch.nn.Module] = None + ) -> bool: + from ..hooks.group_offloading import _is_group_offload_enabled + + components = self.components.values() if module is None else [module] + components = [component for component in components if isinstance(component, torch.nn.Module)] + for component in components: + if _is_group_offload_enabled(component): + if raise_error: + raise ValueError( + "You are trying to apply model/sequential CPU offloading to a pipeline that contains components " + "with group offloading enabled. This is not supported. Please disable group offloading for " + "components of the pipeline to use other offloading methods." + ) + return True + return False + class StableDiffusionMixin: r""" diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py new file mode 100644 index 000000000000..d8f41fc2b1ae --- /dev/null +++ b/tests/hooks/test_group_offloading.py @@ -0,0 +1,214 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers.models import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import get_logger +from diffusers.utils.testing_utils import require_torch_gpu, torch_device + + +class DummyBlock(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.proj_in = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.proj_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.activation(x) + x = self.proj_out(x) + return x + + +class DummyModel(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: + super().__init__() + + self.linear_1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.blocks = torch.nn.ModuleList( + [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] + ) + self.linear_2 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = self.activation(x) + for block in self.blocks: + x = block(x) + x = self.linear_2(x) + return x + + +class DummyPipeline(DiffusionPipeline): + model_cpu_offload_seq = "model" + + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + + self.register_modules(model=model) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + for _ in range(2): + x = x + 0.1 * self.model(x) + return x + + +@require_torch_gpu +class GroupOffloadTests(unittest.TestCase): + in_features = 64 + hidden_features = 256 + out_features = 64 + num_layers = 4 + + def setUp(self): + with torch.no_grad(): + self.model = self.get_model() + self.input = torch.randn((4, self.in_features)).to(torch_device) + + def tearDown(self): + super().tearDown() + + del self.model + del self.input + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + def get_model(self): + torch.manual_seed(0) + return DummyModel( + in_features=self.in_features, + hidden_features=self.hidden_features, + out_features=self.out_features, + num_layers=self.num_layers, + ) + + def test_offloading_forward_pass(self): + @torch.no_grad() + def run_forward(model): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self.assertTrue( + all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") + ) + ) + model.eval() + output = model(self.input)[0].cpu() + max_memory_allocated = torch.cuda.max_memory_allocated() + return output, max_memory_allocated + + self.model.to(torch_device) + output_without_group_offloading, mem_baseline = run_forward(self.model) + self.model.to("cpu") + + model = self.get_model() + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + output_with_group_offloading1, mem1 = run_forward(model) + + model = self.get_model() + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) + output_with_group_offloading2, mem2 = run_forward(model) + + model = self.get_model() + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + output_with_group_offloading3, mem3 = run_forward(model) + + model = self.get_model() + model.enable_group_offload(torch_device, offload_type="leaf_level") + output_with_group_offloading4, mem4 = run_forward(model) + + model = self.get_model() + model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) + output_with_group_offloading5, mem5 = run_forward(model) + + # Precision assertions - offloading should not impact the output + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)) + + # Memory assertions - offloading should reduce memory usage + self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline) + + def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self): + if torch.device(torch_device).type != "cuda": + return + self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + logger = get_logger("diffusers.models.modeling_utils") + logger.setLevel("INFO") + with self.assertLogs(logger, level="WARNING") as cm: + self.model.to(torch_device) + self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0]) + + def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self): + if torch.device(torch_device).type != "cuda": + return + pipe = DummyPipeline(self.model) + self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + logger = get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel("INFO") + with self.assertLogs(logger, level="WARNING") as cm: + pipe.to(torch_device) + self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0]) + + def test_error_raised_if_streams_used_and_no_cuda_device(self): + original_is_available = torch.cuda.is_available + torch.cuda.is_available = lambda: False + with self.assertRaises(ValueError): + self.model.enable_group_offload( + onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True + ) + torch.cuda.is_available = original_is_available + + def test_error_raised_if_supports_group_offloading_false(self): + self.model._supports_group_offloading = False + with self.assertRaisesRegex(ValueError, "does not support group offloading"): + self.model.enable_group_offload(onload_device=torch.device("cuda")) + + def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"): + pipe.enable_model_cpu_offload() + + def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"): + pipe.enable_sequential_cpu_offload() + + def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.enable_model_cpu_offload() + with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"): + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + + def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.enable_sequential_cpu_offload() + with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"): + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index e083d2777a7e..b633c16aaec5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1458,6 +1458,55 @@ def get_memory_usage(storage_dtype, compute_dtype): or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ) + @require_torch_gpu + def test_group_offloading(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + torch.manual_seed(0) + + @torch.no_grad() + def run_forward(model): + self.assertTrue( + all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") + ) + ) + model.eval() + return model(**inputs_dict)[0] + + model = self.model_class(**init_dict) + if not getattr(model, "_supports_group_offloading", True): + return + + model.to(torch_device) + output_without_group_offloading = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) + output_with_group_offloading1 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) + output_with_group_offloading2 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="leaf_level") + output_with_group_offloading3 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) + output_with_group_offloading4 = run_forward(model) + + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 2a4d0a36dffa..30fdd68cfd36 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -58,6 +58,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTes ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index 2dfc36a6ce45..a0fbc5df1c28 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -39,6 +39,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 1b3115c8eb1d..4913a46b8d4f 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -61,6 +61,7 @@ class AnimateDiffPipelineFastTests( ] ) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index bee905f9ae13..f0b67afcc052 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -31,6 +31,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ) batch_params = frozenset(["prompt", "negative_prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 750f20f8fbe5..c09b00e1d16b 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -60,6 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastT ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py index c936bad4c3d5..2e962bd247b9 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -56,6 +56,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py index 102a5c66e624..4619de81d535 100644 --- a/tests/pipelines/cogview3/test_cogview3plus.py +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -57,6 +57,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py index f949cfb2d36d..a39c17bb4f79 100644 --- a/tests/pipelines/consisid/test_consisid.py +++ b/tests/pipelines/consisid/test_consisid.py @@ -59,6 +59,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index e0fc00171031..e2c0c60ddfa4 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -127,6 +127,7 @@ class ControlNetPipelineFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index e75fe8903134..dda6339427f8 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -76,6 +76,7 @@ class StableDiffusionXLControlNetPipelineFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 8b9852dbec6e..cce14342699c 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -51,6 +51,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index e1894d555c3c..04daca27c3dd 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -60,6 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ) batch_params = frozenset(["prompt", "negative_prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components( self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 4c184db99630..1da5b52bd050 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -140,6 +140,7 @@ class ControlNetXSPipelineFastTests( test_attention_slicing = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 7537efe0bbf9..644bb669d8e8 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -79,6 +79,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests( test_attention_slicing = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index bab343a5954c..2382f453bb39 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -35,6 +35,7 @@ class FluxPipelineFastTests( # there is no xformers processor for Flux test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index 7fdb19327213..5bb7cdec034c 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -23,6 +23,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # there is no xformers processor for Flux test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py index 620ecb8a831f..1d488db71ced 100644 --- a/tests/pipelines/flux/test_pipeline_flux_fill.py +++ b/tests/pipelines/flux/test_pipeline_flux_fill.py @@ -24,6 +24,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin): batch_params = frozenset(["prompt"]) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index ba7ec43ec977..dd0f6437df87 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -54,6 +54,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca # there is no xformers processor for Flux test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 64459a659179..315da3ed46ea 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -54,6 +54,7 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True + test_group_offloading = True pab_config = PyramidAttentionBroadcastConfig( spatial_attention_block_skip_range=2, diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 64b366ea8ad6..4f72729fc9ce 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -47,6 +47,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index 7c1923313b23..18dcdef98d7d 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -33,6 +33,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM supports_dduf = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index b7bb844ff311..ed41e82aca9f 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -56,6 +56,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index 747be38d495c..ead6c2b208de 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -56,6 +56,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr ] ) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 7df6656f6f87..ae0f9b50f74e 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -51,6 +51,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index 6e265b9d5eb8..9bfeb691d770 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -56,6 +56,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index f70f9d91f19c..34df808d3320 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -53,6 +53,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 1e700bed03f8..d60092c4e5cb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -124,6 +124,7 @@ class StableDiffusionPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): cross_attention_dim = 8 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 10b8a1818a29..a7375d37eccd 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -76,6 +76,7 @@ class StableDiffusion2PipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index df37090eeba2..24d03a035066 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -36,6 +36,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): ) batch_params = frozenset(["prompt", "negative_prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index f1422022a7aa..dfd1c9c37271 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -76,6 +76,7 @@ class StableDiffusionXLPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index de5faa185c2f..355e851f9fdd 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -29,6 +29,7 @@ StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.hooks import apply_group_offloading from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -47,6 +48,7 @@ require_accelerator, require_hf_hub_version_greater, require_torch, + require_torch_gpu, require_transformers_version_greater, skip_mps, torch_device, @@ -990,6 +992,7 @@ class PipelineTesterMixin: test_xformers_attention = True test_layerwise_casting = False + test_group_offloading = False supports_dduf = True def get_generator(self, seed): @@ -2044,6 +2047,79 @@ def test_layerwise_casting_inference(self): inputs = self.get_dummy_inputs(torch_device) _ = pipe(**inputs)[0] + @require_torch_gpu + def test_group_offloading_inference(self): + if not self.test_group_offloading: + return + + def create_pipe(): + torch.manual_seed(0) + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + return pipe + + def enable_group_offload_on_component(pipe, group_offloading_kwargs): + # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If + # tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of + # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a + # warmup forward pass (even with dummy small inputs) is recommended. + for component_name in [ + "text_encoder", + "text_encoder_2", + "text_encoder_3", + "transformer", + "unet", + "controlnet", + ]: + if not hasattr(pipe, component_name): + continue + component = getattr(pipe, component_name) + if not getattr(component, "_supports_group_offloading", True): + continue + if hasattr(component, "enable_group_offload"): + # For diffusers ModelMixin implementations + component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs) + else: + # For other models not part of diffusers + apply_group_offloading( + component, onload_device=torch.device(torch_device), **group_offloading_kwargs + ) + self.assertTrue( + all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in component.modules() + if hasattr(module, "_diffusers_hook") + ) + ) + for component_name in ["vae", "vqvae"]: + if hasattr(pipe, component_name): + getattr(pipe, component_name).to(torch_device) + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + return pipe(**inputs)[0] + + pipe = create_pipe().to(torch_device) + output_without_group_offloading = run_forward(pipe) + + pipe = create_pipe() + enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1}) + output_with_group_offloading1 = run_forward(pipe) + + pipe = create_pipe() + enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"}) + output_with_group_offloading2 = run_forward(pipe) + + if torch.is_tensor(output_without_group_offloading): + output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() + output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy() + output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy() + + self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4)) + self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4)) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase):