Commit 813d42c

Group offloading improvements (#11094)
update
1 parent b4d7e9c · commit 813d42c
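
For context, group offloading keeps a model's weights on the CPU and onloads them group by group (per block or per leaf module) onto the accelerator right before they are needed, optionally prefetching the next group on a separate CUDA stream. A rough usage sketch follows; the public entry point apply_group_offloading and its exact arguments are not shown in this diff and are assumed from the diffusers documentation of this period, so treat the signature as an approximation rather than a reference.

    import torch
    from torch import nn
    from diffusers.hooks import apply_group_offloading  # assumed public entry point, not part of this diff

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # block_level grouping works over ModuleList / Sequential children
            self.blocks = nn.ModuleList(nn.Linear(256, 256) for _ in range(8))

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    model = ToyModel()

    # use_stream=True is the code path this commit touches: it builds the pinned
    # CPU parameter dict so Host<->Device copies can run asynchronously.
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=2,
        use_stream=True,
    )

    with torch.no_grad():
        out = model(torch.randn(1, 256, device="cuda"))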

File tree

1 file changed: +23 −7 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 23 additions & 7 deletions
@@ -83,7 +83,10 @@ def onload_(self):
 
         with context:
             for group_module in self.modules:
-                group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                for param in group_module.parameters():
+                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                for buffer in group_module.buffers():
+                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
             if self.parameters is not None:
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
@@ -98,6 +101,12 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
+            if self.parameters is not None:
+                for param in self.parameters:
+                    param.data = self.cpu_param_dict[param]
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    buffer.data = self.cpu_param_dict[buffer]
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
     cpu_param_dict = None
     if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+        cpu_param_dict = _get_pinned_cpu_param_dict(module)
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
     cpu_param_dict = None
     if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+        cpu_param_dict = _get_pinned_cpu_param_dict(module)
 
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
+def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
+    cpu_param_dict = {}
+    for param in module.parameters():
+        param.data = param.data.cpu().pin_memory()
+        cpu_param_dict[param] = param.data
+    for buffer in module.buffers():
+        buffer.data = buffer.data.cpu().pin_memory()
+        cpu_param_dict[buffer] = buffer.data
+    return cpu_param_dict
+
+
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
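
The substance of the change: onload_ now moves each group's parameters and buffers tensor by tensor through .data instead of calling group_module.to(...), and the stream-based offload_ path also restores standalone parameters and buffers from the pinned CPU dict, which _get_pinned_cpu_param_dict now populates with buffers as well as parameters. Below is a minimal, self-contained sketch of that pinned-CPU / .data swap pattern; the helper names (pin_module_to_cpu, onload, offload) are illustrative and are not the diffusers classes or hooks.

    import torch
    from torch import nn

    def pin_module_to_cpu(module: nn.Module) -> dict:
        # Pin every parameter and buffer in page-locked CPU memory and remember the pinned copies.
        cpu_copies = {}
        for tensor in list(module.parameters()) + list(module.buffers()):
            tensor.data = tensor.data.cpu().pin_memory()
            cpu_copies[tensor] = tensor.data
        return cpu_copies

    def onload(module: nn.Module, device: torch.device, stream: torch.cuda.Stream) -> None:
        # Queue Host->Device copies on a side stream; pinned source memory makes
        # non_blocking=True genuinely asynchronous.
        with torch.cuda.stream(stream):
            for tensor in list(module.parameters()) + list(module.buffers()):
                tensor.data = tensor.data.to(device, non_blocking=True)

    def offload(module: nn.Module, cpu_copies: dict) -> None:
        # Offloading is just re-pointing .data at the pinned CPU copy; no Device->Host
        # copy is needed for inference-only weights.
        torch.cuda.current_stream().synchronize()
        for tensor in list(module.parameters()) + list(module.buffers()):
            tensor.data = cpu_copies[tensor]

    if torch.cuda.is_available():
        block = nn.Linear(1024, 1024)
        pinned = pin_module_to_cpu(block)
        side_stream = torch.cuda.Stream()
        onload(block, torch.device("cuda"), side_stream)
        torch.cuda.current_stream().wait_stream(side_stream)  # make the copies visible to the compute stream
        y = block(torch.randn(1, 1024, device="cuda"))
        offload(block, pinned)

Operating on .data per tensor, rather than calling module.to, keeps the Parameter and buffer objects themselves stable, which is what allows the pinned CPU copies recorded in cpu_param_dict to be swapped back in on offload without reallocating or re-registering anything.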
