From 7a537f61e2bbe2ad7d82cb28117491af9d999645 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 17 Mar 2025 23:53:59 +0100 Subject: [PATCH 1/2] update --- src/diffusers/hooks/group_offloading.py | 30 +++++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index c389c5dc9826..286fd941ff73 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -83,7 +83,10 @@ def onload_(self): with context: for group_module in self.modules: - group_module.to(self.onload_device, non_blocking=self.non_blocking) + for param in group_module.parameters(): + param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + for buffer in group_module.buffers(): + buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) if self.parameters is not None: for param in self.parameters: param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) @@ -98,6 +101,12 @@ def offload_(self): for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] + if self.parameters is not None: + for param in self.parameters: + param.data = self.cpu_param_dict[param] + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = self.cpu_param_dict[buffer] else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=self.non_blocking) @@ -387,9 +396,7 @@ def _apply_group_offloading_block_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} + cpu_param_dict = _get_pinned_cpu_param_dict(module) # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} + cpu_param_dict = _get_pinned_cpu_param_dict(module) # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() @@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook( registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) +def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]: + cpu_param_dict = {} + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict[param] = param.data + for buffer in module.buffers(): + buffer.data = buffer.data.cpu().pin_memory() + cpu_param_dict[buffer] = buffer.data + return cpu_param_dict + + def _gather_parameters_with_no_group_offloading_parent( module: torch.nn.Module, modules_with_group_offloading: Set[str] ) -> List[torch.nn.Parameter]: From 8517bfefeba8ec7cbd129445235f8555209bdb2d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 18 Mar 2025 01:13:13 +0100 Subject: [PATCH 2/2] update --- src/diffusers/hooks/group_offloading.py | 27 ++++++++++++++++--------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 286fd941ff73..e4b9ed9307ea 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -181,6 +181,13 @@ def __init__(self): self._layer_execution_tracker_module_names = set() def initialize_hook(self, module): + def make_execution_order_update_callback(current_name, current_submodule): + def callback(): + logger.debug(f"Adding {current_name} to the execution order") + self.execution_order.append((current_name, current_submodule)) + + return callback + # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the # layers are executed during the forward pass. @@ -192,14 +199,8 @@ def initialize_hook(self, module): group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING) if group_offloading_hook is not None: - - def make_execution_order_update_callback(current_name, current_submodule): - def callback(): - logger.debug(f"Adding {current_name} to the execution order") - self.execution_order.append((current_name, current_submodule)) - - return callback - + # For the first forward pass, we have to load in a blocking manner + group_offloading_hook.group.non_blocking = False layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule)) registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER) self._layer_execution_tracker_module_names.add(name) @@ -229,6 +230,7 @@ def post_forward(self, module, output): # Remove the layer execution tracker hooks from the submodules base_module_registry = module._diffusers_hook registries = [submodule._diffusers_hook for _, submodule in self.execution_order] + group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] for i in range(num_executed): registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False) @@ -236,8 +238,13 @@ def post_forward(self, module, output): # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False) - # Apply lazy prefetching by setting required attributes - group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] + # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True. + # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to + # see the benefits of prefetching. + for hook in group_offloading_hooks: + hook.group.non_blocking = True + + # Set required attributes for prefetching if num_executed > 0: base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group